Keep side effectful ops in grappler function items
PiperOrigin-RevId: 221653198
This commit is contained in:
parent
1ffa8477eb
commit
76d204f387
@ -571,6 +571,10 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
|
||||
if (node.op().find("Queue") != string::npos) {
|
||||
return false;
|
||||
}
|
||||
// Sending a tensor via a network is a side effect.
|
||||
if (IsSend(node)) {
|
||||
return false;
|
||||
}
|
||||
return !ModifiesInputsInPlace(node);
|
||||
}
|
||||
|
||||
|
@ -347,12 +347,6 @@ GrapplerFunctionItem::GrapplerFunctionItem(
|
||||
fetch.push_back(output_tensor);
|
||||
}
|
||||
}
|
||||
// Stateful and Send (it's not stateful) nodes must be preserved in the graph.
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
if (IsSend(node)) {
|
||||
keep_ops.push_back(node.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const string& GrapplerFunctionItem::description() const { return description_; }
|
||||
@ -584,8 +578,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
||||
TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
|
||||
&connectivity));
|
||||
|
||||
// Stateful and Send nodes must be preserved in a function body
|
||||
if (registration->op_def.is_stateful() || IsSend(func_def_node)) {
|
||||
// Ops with side effects must be preserved in a function body.
|
||||
if (!IsFreeOfSideEffect(func_def_node)) {
|
||||
keep_nodes.push_back(func_def_node.name());
|
||||
}
|
||||
}
|
||||
|
@ -142,12 +142,6 @@ class GrapplerFunctionItemInstantiation {
|
||||
class GrapplerFunctionItem : public GrapplerItem {
|
||||
public:
|
||||
GrapplerFunctionItem() = default;
|
||||
GrapplerFunctionItem(string func_name, string description,
|
||||
AttrSlice func_attr,
|
||||
std::vector<InputArgExpansion> input_arg_expansions,
|
||||
std::vector<OutputArgExpansion> output_arg_expansions,
|
||||
std::vector<string> keep_nodes, int graph_def_version,
|
||||
bool is_stateful, GraphDef&& function_body);
|
||||
|
||||
const string& description() const;
|
||||
|
||||
@ -170,12 +164,22 @@ class GrapplerFunctionItem : public GrapplerItem {
|
||||
GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
|
||||
|
||||
private:
|
||||
friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&,
|
||||
const FunctionLibraryDefinition&, int,
|
||||
GrapplerFunctionItem*);
|
||||
friend Status ReplaceInputWithConst(const NodeDef&, int,
|
||||
GrapplerFunctionItem*);
|
||||
friend Status RemoveUnusedOutputs(
|
||||
const gtl::FlatSet<int>& active_outputs, GrapplerFunctionItem* item,
|
||||
std::vector<std::pair<int, int>>* output_mapping);
|
||||
|
||||
GrapplerFunctionItem(string func_name, string description,
|
||||
AttrSlice func_attr,
|
||||
std::vector<InputArgExpansion> input_arg_expansions,
|
||||
std::vector<OutputArgExpansion> output_arg_expansions,
|
||||
std::vector<string> keep_nodes, int graph_def_version,
|
||||
bool is_stateful, GraphDef&& function_body);
|
||||
|
||||
string description_;
|
||||
AttrSlice func_attr_; // Attributes specific to function definition that
|
||||
// produced this item (FuncDef.attr field).
|
||||
|
@ -576,6 +576,33 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
|
||||
EXPECT_EQ("two", cast.input(0));
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, FromFunctionDefWithSideEffectfulOps) {
|
||||
const Tensor kOne = test::AsScalar<float>(1.0);
|
||||
FunctionDef func = FunctionDefHelper::Define(
|
||||
/* Name */ "SideEffects",
|
||||
/* Args */ {"x: Ref(float)"},
|
||||
/* Return values */ {},
|
||||
/* Attr def */ {},
|
||||
/* Nodes */
|
||||
{{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_FLOAT}}},
|
||||
{{"update"}, "AssignAdd", {"x", "one"}, {{"T", DT_FLOAT}}}});
|
||||
|
||||
protobuf::Map<string, AttrValue> func_instantiation_attr;
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
|
||||
|
||||
GrapplerFunctionItem item;
|
||||
TF_EXPECT_OK(MakeGrapplerFunctionItem(func,
|
||||
AttrSlice(&func_instantiation_attr),
|
||||
flib, TF_GRAPH_DEF_VERSION, &item));
|
||||
|
||||
EXPECT_EQ("SideEffects", item.id);
|
||||
EXPECT_EQ(3, item.function_body().node_size());
|
||||
EXPECT_EQ(1, item.input_size());
|
||||
EXPECT_EQ(0, item.output_size());
|
||||
ASSERT_EQ(1, item.keep_ops.size());
|
||||
EXPECT_EQ("update", item.keep_ops[0]);
|
||||
}
|
||||
|
||||
TEST_F(FunctionsTest, MakeFunctionDef) {
|
||||
const Tensor kTwo = test::AsScalar<int64>(2);
|
||||
FunctionDef func = FunctionDefHelper::Define(
|
||||
|
Loading…
Reference in New Issue
Block a user