Keep side effectful ops in grappler function items

PiperOrigin-RevId: 221653198
This commit is contained in:
Eugene Zhulenev 2018-11-15 11:01:24 -08:00 committed by TensorFlower Gardener
parent 1ffa8477eb
commit 76d204f387
4 changed files with 43 additions and 14 deletions

View File

@ -571,6 +571,10 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
if (node.op().find("Queue") != string::npos) {
return false;
}
// Sending a tensor via a network is a side effect.
if (IsSend(node)) {
return false;
}
return !ModifiesInputsInPlace(node);
}

View File

@ -347,12 +347,6 @@ GrapplerFunctionItem::GrapplerFunctionItem(
fetch.push_back(output_tensor);
}
}
// Stateful and Send (it's not stateful) nodes must be preserved in the graph.
for (const NodeDef& node : graph.node()) {
if (IsSend(node)) {
keep_ops.push_back(node.name());
}
}
}
const string& GrapplerFunctionItem::description() const { return description_; }
@ -584,8 +578,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
&connectivity));
// Stateful and Send nodes must be preserved in a function body
if (registration->op_def.is_stateful() || IsSend(func_def_node)) {
// Ops with side effects must be preserved in a function body.
if (!IsFreeOfSideEffect(func_def_node)) {
keep_nodes.push_back(func_def_node.name());
}
}

View File

@ -142,12 +142,6 @@ class GrapplerFunctionItemInstantiation {
class GrapplerFunctionItem : public GrapplerItem {
public:
GrapplerFunctionItem() = default;
GrapplerFunctionItem(string func_name, string description,
AttrSlice func_attr,
std::vector<InputArgExpansion> input_arg_expansions,
std::vector<OutputArgExpansion> output_arg_expansions,
std::vector<string> keep_nodes, int graph_def_version,
bool is_stateful, GraphDef&& function_body);
const string& description() const;
@ -170,12 +164,22 @@ class GrapplerFunctionItem : public GrapplerItem {
GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
private:
friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&,
const FunctionLibraryDefinition&, int,
GrapplerFunctionItem*);
friend Status ReplaceInputWithConst(const NodeDef&, int,
GrapplerFunctionItem*);
friend Status RemoveUnusedOutputs(
const gtl::FlatSet<int>& active_outputs, GrapplerFunctionItem* item,
std::vector<std::pair<int, int>>* output_mapping);
GrapplerFunctionItem(string func_name, string description,
AttrSlice func_attr,
std::vector<InputArgExpansion> input_arg_expansions,
std::vector<OutputArgExpansion> output_arg_expansions,
std::vector<string> keep_nodes, int graph_def_version,
bool is_stateful, GraphDef&& function_body);
string description_;
AttrSlice func_attr_; // Attributes specific to function definition that
// produced this item (FuncDef.attr field).

View File

@ -576,6 +576,33 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
EXPECT_EQ("two", cast.input(0));
}
TEST_F(FunctionsTest, FromFunctionDefWithSideEffectfulOps) {
const Tensor kOne = test::AsScalar<float>(1.0);
FunctionDef func = FunctionDefHelper::Define(
/* Name */ "SideEffects",
/* Args */ {"x: Ref(float)"},
/* Return values */ {},
/* Attr def */ {},
/* Nodes */
{{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_FLOAT}}},
{{"update"}, "AssignAdd", {"x", "one"}, {{"T", DT_FLOAT}}}});
protobuf::Map<string, AttrValue> func_instantiation_attr;
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
TF_EXPECT_OK(MakeGrapplerFunctionItem(func,
AttrSlice(&func_instantiation_attr),
flib, TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("SideEffects", item.id);
EXPECT_EQ(3, item.function_body().node_size());
EXPECT_EQ(1, item.input_size());
EXPECT_EQ(0, item.output_size());
ASSERT_EQ(1, item.keep_ops.size());
EXPECT_EQ("update", item.keep_ops[0]);
}
TEST_F(FunctionsTest, MakeFunctionDef) {
const Tensor kTwo = test::AsScalar<int64>(2);
FunctionDef func = FunctionDefHelper::Define(