Preserve FunctionDef.arg_attr in GrapplerFunctionItem.

PiperOrigin-RevId: 316498288
Change-Id: I6c3288c725bb281cca17256146c9ec3fd8cec5f0
This commit is contained in:
Yujing Zhang 2020-06-15 10:48:56 -07:00 committed by TensorFlower Gardener
parent 80b3b4fa9f
commit 67487368bb
4 changed files with 49 additions and 4 deletions

View File

@ -434,9 +434,11 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
// _Arg/Placeholder nodes.
if (absl::StartsWith(attr.first, "_")) {
arg_attrs.mutable_attr()->insert(attr);
} else if (attr.first == "shape") {
} else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) {
// Preserve known shapes by moving them to the _output_shapes list.
// The _Arg shape function knows how to extract them from there.
// Don't preserve the shape of a resource arg node, which is a scalar
// resource handle.
AttrValue value;
*(value.mutable_list()->add_shape()) = attr.second.shape();
arg_attrs.mutable_attr()->insert({"_output_shapes", value});

View File

@ -38,12 +38,14 @@ namespace grappler {
GrapplerFunctionItem::GrapplerFunctionItem(
string func_name, string description, AttrSlice func_attr,
std::vector<const FunctionDef::ArgAttrs*> arg_attr,
std::vector<InputArgInstantiation> input_args,
std::vector<OutputArgInstantiation> output_args,
std::vector<ControlOutput> control_outputs, const int graph_def_version,
const bool is_stateful, GraphDef&& function_body)
: description_(std::move(description)),
func_attr_(func_attr),
arg_attr_(std::move(arg_attr)),
input_args_(std::move(input_args)),
output_args_(std::move(output_args)),
control_outputs_(std::move(control_outputs)),
@ -108,6 +110,11 @@ const std::size_t GrapplerFunctionItem::control_output_size() const {
const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
const std::vector<const FunctionDef::ArgAttrs*>&
GrapplerFunctionItem::arg_attr() const {
return arg_attr_;
}
const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
@ -278,12 +285,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
control_outputs.push_back({control_ret.first, control_ret.second});
}
std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
for (const auto& attr : func.arg_attr()) {
arg_attr.at(attr.first) = &attr.second;
}
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(),
/*description=*/signature.description(),
/*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
std::move(outputs), std::move(control_outputs), graph_def_version,
signature.is_stateful(), std::move(function_body));
/*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
std::move(inputs), std::move(outputs), std::move(control_outputs),
graph_def_version, signature.is_stateful(), std::move(function_body));
return Status::OK();
}
@ -330,6 +342,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
}
item->input_args_.erase(item->input_args_.begin() + input_index);
item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
return Status::OK();
}
@ -566,6 +579,14 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
(*func->mutable_attr())[attr_name] = attr_value;
}
// Copy function arg attributes.
for (int i = 0; i < item.arg_attr().size(); ++i) {
const auto* attr = item.arg_attr().at(i);
if (attr != nullptr) {
(*func->mutable_arg_attr())[i] = *attr;
}
}
// Copy function body nodes to the FunctionDef and update input format
for (const NodeDef& func_node : item.function_body().node()) {
// Skip original `_Arg` and `_Retval` nodes. If node was converted to some

View File

@ -76,6 +76,7 @@ class GrapplerFunctionItem : public GrapplerItem {
const std::size_t control_output_size() const;
const AttrSlice& func_attr() const;
const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const;
const GraphDef& function_body() const;
GraphDef& mutable_function_body();
@ -95,6 +96,7 @@ class GrapplerFunctionItem : public GrapplerItem {
GrapplerFunctionItem(string func_name, string description,
AttrSlice func_attr,
std::vector<const FunctionDef::ArgAttrs*> arg_attr,
std::vector<InputArgInstantiation> input_args,
std::vector<OutputArgInstantiation> output_args,
std::vector<ControlOutput> control_outputs,
@ -105,6 +107,9 @@ class GrapplerFunctionItem : public GrapplerItem {
AttrSlice func_attr_; // Attributes specific to function definition that
// produced this item (FuncDef.attr field).
// Attributes of function arguments
std::vector<const FunctionDef::ArgAttrs*> arg_attr_;
std::vector<InputArgInstantiation> input_args_;
std::vector<OutputArgInstantiation> output_args_;
std::vector<ControlOutput> control_outputs_;

View File

@ -523,6 +523,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
{{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
});
// Add an attribute to _Arg 0;
const uint32 arg_index = 0;
const std::pair<string, string> arg_attr_key_and_value = {"_arg_attr", "abc"};
FunctionDef::ArgAttrs arg_attr;
(*arg_attr.mutable_attr())[arg_attr_key_and_value.first].set_s(
arg_attr_key_and_value.second);
(*func.mutable_arg_attr())[arg_index] = arg_attr;
protobuf::Map<string, AttrValue> func_instantiation_attr;
func_instantiation_attr["T"].set_type(DT_FLOAT);
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
@ -541,6 +549,15 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
EXPECT_EQ("y", specialized.signature().output_arg(0).name());
EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
EXPECT_EQ(specialized.arg_attr().size(), 1);
EXPECT_EQ(specialized.arg_attr().at(arg_index).attr().size(), 1);
EXPECT_EQ(specialized.arg_attr()
.at(arg_index)
.attr()
.at(arg_attr_key_and_value.first)
.s(),
arg_attr_key_and_value.second);
// Function body specialized for instantiation types.
int count = 0;
for (const NodeDef &node : specialized.node_def()) {