Preserve FunctionDef.arg_attr in GrapplerFunctionItem.
PiperOrigin-RevId: 316498288 Change-Id: I6c3288c725bb281cca17256146c9ec3fd8cec5f0
This commit is contained in:
parent
80b3b4fa9f
commit
67487368bb
tensorflow/core
@ -434,9 +434,11 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
|||||||
// _Arg/Placeholder nodes.
|
// _Arg/Placeholder nodes.
|
||||||
if (absl::StartsWith(attr.first, "_")) {
|
if (absl::StartsWith(attr.first, "_")) {
|
||||||
arg_attrs.mutable_attr()->insert(attr);
|
arg_attrs.mutable_attr()->insert(attr);
|
||||||
} else if (attr.first == "shape") {
|
} else if (attr.first == "shape" && argdef->type() != DT_RESOURCE) {
|
||||||
// Preserve known shapes by moving them to the _output_shapes list.
|
// Preserve known shapes by moving them to the _output_shapes list.
|
||||||
// The _Arg shape function knows how to extract them from there.
|
// The _Arg shape function knows how to extract them from there.
|
||||||
|
// Don't preserve the shape of a resource arg node, which is a scalar
|
||||||
|
// resource handle.
|
||||||
AttrValue value;
|
AttrValue value;
|
||||||
*(value.mutable_list()->add_shape()) = attr.second.shape();
|
*(value.mutable_list()->add_shape()) = attr.second.shape();
|
||||||
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
|
arg_attrs.mutable_attr()->insert({"_output_shapes", value});
|
||||||
|
@ -38,12 +38,14 @@ namespace grappler {
|
|||||||
|
|
||||||
GrapplerFunctionItem::GrapplerFunctionItem(
|
GrapplerFunctionItem::GrapplerFunctionItem(
|
||||||
string func_name, string description, AttrSlice func_attr,
|
string func_name, string description, AttrSlice func_attr,
|
||||||
|
std::vector<const FunctionDef::ArgAttrs*> arg_attr,
|
||||||
std::vector<InputArgInstantiation> input_args,
|
std::vector<InputArgInstantiation> input_args,
|
||||||
std::vector<OutputArgInstantiation> output_args,
|
std::vector<OutputArgInstantiation> output_args,
|
||||||
std::vector<ControlOutput> control_outputs, const int graph_def_version,
|
std::vector<ControlOutput> control_outputs, const int graph_def_version,
|
||||||
const bool is_stateful, GraphDef&& function_body)
|
const bool is_stateful, GraphDef&& function_body)
|
||||||
: description_(std::move(description)),
|
: description_(std::move(description)),
|
||||||
func_attr_(func_attr),
|
func_attr_(func_attr),
|
||||||
|
arg_attr_(std::move(arg_attr)),
|
||||||
input_args_(std::move(input_args)),
|
input_args_(std::move(input_args)),
|
||||||
output_args_(std::move(output_args)),
|
output_args_(std::move(output_args)),
|
||||||
control_outputs_(std::move(control_outputs)),
|
control_outputs_(std::move(control_outputs)),
|
||||||
@ -108,6 +110,11 @@ const std::size_t GrapplerFunctionItem::control_output_size() const {
|
|||||||
|
|
||||||
const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
|
const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
|
||||||
|
|
||||||
|
const std::vector<const FunctionDef::ArgAttrs*>&
|
||||||
|
GrapplerFunctionItem::arg_attr() const {
|
||||||
|
return arg_attr_;
|
||||||
|
}
|
||||||
|
|
||||||
const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
|
const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
|
||||||
|
|
||||||
GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
|
GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
|
||||||
@ -278,12 +285,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
|||||||
control_outputs.push_back({control_ret.first, control_ret.second});
|
control_outputs.push_back({control_ret.first, control_ret.second});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<const FunctionDef::ArgAttrs*> arg_attr(inputs.size(), nullptr);
|
||||||
|
for (const auto& attr : func.arg_attr()) {
|
||||||
|
arg_attr.at(attr.first) = &attr.second;
|
||||||
|
}
|
||||||
|
|
||||||
*item = GrapplerFunctionItem(
|
*item = GrapplerFunctionItem(
|
||||||
/*func_name=*/signature.name(),
|
/*func_name=*/signature.name(),
|
||||||
/*description=*/signature.description(),
|
/*description=*/signature.description(),
|
||||||
/*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
|
/*func_attr=*/AttrSlice(&func.attr()), std::move(arg_attr),
|
||||||
std::move(outputs), std::move(control_outputs), graph_def_version,
|
std::move(inputs), std::move(outputs), std::move(control_outputs),
|
||||||
signature.is_stateful(), std::move(function_body));
|
graph_def_version, signature.is_stateful(), std::move(function_body));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,6 +342,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
|
|||||||
}
|
}
|
||||||
|
|
||||||
item->input_args_.erase(item->input_args_.begin() + input_index);
|
item->input_args_.erase(item->input_args_.begin() + input_index);
|
||||||
|
item->arg_attr_.erase(item->arg_attr_.begin() + input_index);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -566,6 +579,14 @@ Status MakeFunctionDef(const GrapplerFunctionItem& item,
|
|||||||
(*func->mutable_attr())[attr_name] = attr_value;
|
(*func->mutable_attr())[attr_name] = attr_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy function arg attributes.
|
||||||
|
for (int i = 0; i < item.arg_attr().size(); ++i) {
|
||||||
|
const auto* attr = item.arg_attr().at(i);
|
||||||
|
if (attr != nullptr) {
|
||||||
|
(*func->mutable_arg_attr())[i] = *attr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Copy function body nodes to the FunctionDef and update input format
|
// Copy function body nodes to the FunctionDef and update input format
|
||||||
for (const NodeDef& func_node : item.function_body().node()) {
|
for (const NodeDef& func_node : item.function_body().node()) {
|
||||||
// Skip original `_Arg` and `_Retval` nodes. If node was converted to some
|
// Skip original `_Arg` and `_Retval` nodes. If node was converted to some
|
||||||
|
@ -76,6 +76,7 @@ class GrapplerFunctionItem : public GrapplerItem {
|
|||||||
const std::size_t control_output_size() const;
|
const std::size_t control_output_size() const;
|
||||||
|
|
||||||
const AttrSlice& func_attr() const;
|
const AttrSlice& func_attr() const;
|
||||||
|
const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const;
|
||||||
const GraphDef& function_body() const;
|
const GraphDef& function_body() const;
|
||||||
GraphDef& mutable_function_body();
|
GraphDef& mutable_function_body();
|
||||||
|
|
||||||
@ -95,6 +96,7 @@ class GrapplerFunctionItem : public GrapplerItem {
|
|||||||
|
|
||||||
GrapplerFunctionItem(string func_name, string description,
|
GrapplerFunctionItem(string func_name, string description,
|
||||||
AttrSlice func_attr,
|
AttrSlice func_attr,
|
||||||
|
std::vector<const FunctionDef::ArgAttrs*> arg_attr,
|
||||||
std::vector<InputArgInstantiation> input_args,
|
std::vector<InputArgInstantiation> input_args,
|
||||||
std::vector<OutputArgInstantiation> output_args,
|
std::vector<OutputArgInstantiation> output_args,
|
||||||
std::vector<ControlOutput> control_outputs,
|
std::vector<ControlOutput> control_outputs,
|
||||||
@ -105,6 +107,9 @@ class GrapplerFunctionItem : public GrapplerItem {
|
|||||||
AttrSlice func_attr_; // Attributes specific to function definition that
|
AttrSlice func_attr_; // Attributes specific to function definition that
|
||||||
// produced this item (FuncDef.attr field).
|
// produced this item (FuncDef.attr field).
|
||||||
|
|
||||||
|
// Attributes of function arguments
|
||||||
|
std::vector<const FunctionDef::ArgAttrs*> arg_attr_;
|
||||||
|
|
||||||
std::vector<InputArgInstantiation> input_args_;
|
std::vector<InputArgInstantiation> input_args_;
|
||||||
std::vector<OutputArgInstantiation> output_args_;
|
std::vector<OutputArgInstantiation> output_args_;
|
||||||
std::vector<ControlOutput> control_outputs_;
|
std::vector<ControlOutput> control_outputs_;
|
||||||
|
@ -523,6 +523,14 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
|
|||||||
{{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
|
{{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Add an attribute to _Arg 0;
|
||||||
|
const uint32 arg_index = 0;
|
||||||
|
const std::pair<string, string> arg_attr_key_and_value = {"_arg_attr", "abc"};
|
||||||
|
FunctionDef::ArgAttrs arg_attr;
|
||||||
|
(*arg_attr.mutable_attr())[arg_attr_key_and_value.first].set_s(
|
||||||
|
arg_attr_key_and_value.second);
|
||||||
|
(*func.mutable_arg_attr())[arg_index] = arg_attr;
|
||||||
|
|
||||||
protobuf::Map<string, AttrValue> func_instantiation_attr;
|
protobuf::Map<string, AttrValue> func_instantiation_attr;
|
||||||
func_instantiation_attr["T"].set_type(DT_FLOAT);
|
func_instantiation_attr["T"].set_type(DT_FLOAT);
|
||||||
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
|
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
|
||||||
@ -541,6 +549,15 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
|
|||||||
EXPECT_EQ("y", specialized.signature().output_arg(0).name());
|
EXPECT_EQ("y", specialized.signature().output_arg(0).name());
|
||||||
EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
|
EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
|
||||||
|
|
||||||
|
EXPECT_EQ(specialized.arg_attr().size(), 1);
|
||||||
|
EXPECT_EQ(specialized.arg_attr().at(arg_index).attr().size(), 1);
|
||||||
|
EXPECT_EQ(specialized.arg_attr()
|
||||||
|
.at(arg_index)
|
||||||
|
.attr()
|
||||||
|
.at(arg_attr_key_and_value.first)
|
||||||
|
.s(),
|
||||||
|
arg_attr_key_and_value.second);
|
||||||
|
|
||||||
// Function body specialized for instantiation types.
|
// Function body specialized for instantiation types.
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for (const NodeDef &node : specialized.node_def()) {
|
for (const NodeDef &node : specialized.node_def()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user