Provide an option to use ApiDef instead of OpGenOverrides when generating C++ API. Also, updating UpdateDocs method to ApiDef to replace names in docs.
PiperOrigin-RevId: 176167953
This commit is contained in:
parent
3cc43816cd
commit
cb12ebe044
@ -421,6 +421,7 @@ tf_cc_test(
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
api_def_srcs = ["//tensorflow/core:base_api_def"],
|
||||
op_lib_names = [
|
||||
"array_ops",
|
||||
"audio_ops",
|
||||
@ -525,6 +526,9 @@ cc_library_with_android_deps(
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
data = [
|
||||
"//tensorflow/core:base_api_def",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -536,6 +540,29 @@ cc_library_with_android_deps(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cc_op_gen_test",
|
||||
srcs = [
|
||||
"framework/cc_op_gen.cc",
|
||||
"framework/cc_op_gen.h",
|
||||
"framework/cc_op_gen_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:op_gen_overrides_proto_cc",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_op_op_lib",
|
||||
srcs = ["framework/test_op.cc"],
|
||||
|
@ -18,8 +18,10 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
@ -385,10 +387,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) {
|
||||
}
|
||||
|
||||
bool HasOptionalAttrs(
|
||||
const OpDef& op_def,
|
||||
const ApiDef& api_def,
|
||||
const std::unordered_map<string, string>& inferred_input_attrs) {
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
for (int i = 0; i < api_def.attr_size(); ++i) {
|
||||
const auto& attr(api_def.attr(i));
|
||||
if ((inferred_input_attrs.find(attr.name()) ==
|
||||
inferred_input_attrs.end()) &&
|
||||
attr.has_default_value()) {
|
||||
@ -398,12 +400,21 @@ bool HasOptionalAttrs(
|
||||
return false;
|
||||
}
|
||||
|
||||
const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
|
||||
for (int i = 0; i < api_def.in_arg_size(); ++i) {
|
||||
if (api_def.in_arg(i).name() == name) {
|
||||
return &api_def.in_arg(i);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct OpInfo {
|
||||
// graph_op_def: The OpDef used by the runtime, has the names that
|
||||
// must be used when calling NodeBuilder.
|
||||
// interface_op_def: The OpDef used in the interface in the generated
|
||||
// code, with possibly overridden names and defaults.
|
||||
explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
|
||||
explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
|
||||
const std::vector<string>& aliases);
|
||||
string GetOpAttrStruct() const;
|
||||
string GetConstructorDecl(StringPiece op_name_prefix,
|
||||
@ -423,74 +434,81 @@ struct OpInfo {
|
||||
string comment;
|
||||
|
||||
const OpDef& graph_op_def;
|
||||
const OpDef& op_def;
|
||||
const ApiDef& api_def;
|
||||
const std::vector<string>& aliases;
|
||||
// Map from type attribute to corresponding original argument name.
|
||||
std::unordered_map<string, string> inferred_input_attrs;
|
||||
};
|
||||
|
||||
OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
|
||||
const std::vector<string>& a)
|
||||
: graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
|
||||
op_name = op_def.name();
|
||||
InferOpAttributes(op_def, &inferred_input_attrs);
|
||||
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
|
||||
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
|
||||
const std::vector<string>& aliases)
|
||||
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
|
||||
op_name = api_def.endpoint(0).name();
|
||||
InferOpAttributes(graph_op_def, &inferred_input_attrs);
|
||||
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
|
||||
arg_types.push_back("const ::tensorflow::Scope&");
|
||||
arg_names.push_back("scope");
|
||||
|
||||
if (op_def.has_deprecation()) {
|
||||
if (!op_def.summary().empty()) {
|
||||
comment = strings::StrCat(op_def.summary(), "\n");
|
||||
if (graph_op_def.has_deprecation()) {
|
||||
if (!api_def.summary().empty()) {
|
||||
comment = strings::StrCat(api_def.summary(), "\n");
|
||||
}
|
||||
strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
|
||||
op_def.deprecation().version(), ":\n",
|
||||
op_def.deprecation().explanation(), ".\n");
|
||||
} else if (op_def.summary().empty()) {
|
||||
graph_op_def.deprecation().version(), ":\n",
|
||||
graph_op_def.deprecation().explanation(), ".\n");
|
||||
} else if (api_def.summary().empty()) {
|
||||
comment = "TODO: add doc.\n";
|
||||
} else {
|
||||
comment = strings::StrCat(op_def.summary(), "\n");
|
||||
comment = strings::StrCat(api_def.summary(), "\n");
|
||||
}
|
||||
if (!op_def.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
|
||||
if (!api_def.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", api_def.description(), "\n");
|
||||
}
|
||||
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
|
||||
|
||||
// Process inputs
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
for (int i = 0; i < api_def.arg_order_size(); ++i) {
|
||||
const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
|
||||
const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
|
||||
arg_types.push_back(strings::StrCat(
|
||||
"::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
|
||||
arg_names.push_back(AvoidCPPKeywords(arg.name()));
|
||||
arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
|
||||
|
||||
// TODO(keveman): Include input type information.
|
||||
StringPiece description = arg.description();
|
||||
StringPiece description = api_def_arg.description();
|
||||
if (!description.empty()) {
|
||||
ConsumeEquals(&description);
|
||||
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
|
||||
arg.description(), "\n");
|
||||
strings::StrAppend(&comment, "* ",
|
||||
AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
|
||||
api_def_arg.description(), "\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Process attrs
|
||||
string required_attrs_comment;
|
||||
string optional_attrs_comment;
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
|
||||
// ApiDef attributes must be in the same order as in OpDef since
|
||||
// we initialize ApiDef based on OpDef.
|
||||
const auto& attr(graph_op_def.attr(i));
|
||||
const auto& api_def_attr(api_def.attr(i));
|
||||
CHECK_EQ(attr.name(), api_def_attr.name());
|
||||
// Skip inferred arguments
|
||||
if (inferred_input_attrs.count(attr.name()) > 0) continue;
|
||||
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
string attr_name = AvoidCPPKeywords(attr.name());
|
||||
string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
|
||||
|
||||
string attr_comment;
|
||||
if (!attr.description().empty()) {
|
||||
if (!api_def_attr.description().empty()) {
|
||||
// TODO(keveman): Word wrap and indent this, to handle multi-line
|
||||
// descriptions.
|
||||
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
|
||||
attr.description(), "\n");
|
||||
api_def_attr.description(), "\n");
|
||||
}
|
||||
if (attr.has_default_value()) {
|
||||
if (api_def_attr.has_default_value()) {
|
||||
strings::StrAppend(&optional_attrs_comment, attr_comment);
|
||||
} else {
|
||||
strings::StrAppend(&required_attrs_comment, attr_comment);
|
||||
@ -508,44 +526,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
|
||||
}
|
||||
|
||||
// Process outputs
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
const auto& arg = op_def.output_arg(i);
|
||||
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
|
||||
// ApiDef arguments must be in the same order as in OpDef since
|
||||
// we initialize ApiDef based on OpDef.
|
||||
const auto& arg = graph_op_def.output_arg(i);
|
||||
const auto& api_def_arg(api_def.out_arg(i));
|
||||
CHECK_EQ(arg.name(), api_def_arg.name());
|
||||
|
||||
bool is_list = ArgIsList(arg);
|
||||
output_types.push_back(
|
||||
strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
|
||||
output_names.push_back(AvoidCPPKeywords(arg.name()));
|
||||
output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
|
||||
is_list_output.push_back(is_list);
|
||||
}
|
||||
|
||||
strings::StrAppend(&comment, "\nReturns:\n");
|
||||
if (op_def.output_arg_size() == 0) { // No outputs.
|
||||
if (graph_op_def.output_arg_size() == 0) { // No outputs.
|
||||
strings::StrAppend(&comment, "* the created `Operation`\n");
|
||||
} else if (op_def.output_arg_size() == 1) { // One output
|
||||
} else if (graph_op_def.output_arg_size() == 1) { // One output
|
||||
if (is_list_output[0]) {
|
||||
strings::StrAppend(&comment, "* `OutputList`: ");
|
||||
} else {
|
||||
strings::StrAppend(&comment, "* `Output`: ");
|
||||
}
|
||||
if (op_def.output_arg(0).description().empty()) {
|
||||
strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
|
||||
if (api_def.out_arg(0).description().empty()) {
|
||||
strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
|
||||
" tensor.\n");
|
||||
} else {
|
||||
// TODO(josh11b): Word wrap this.
|
||||
strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
|
||||
strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
|
||||
}
|
||||
} else { // Multiple outputs.
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
|
||||
if (is_list_output[i]) {
|
||||
strings::StrAppend(&comment, "* `OutputList`");
|
||||
} else {
|
||||
strings::StrAppend(&comment, "* `Output`");
|
||||
}
|
||||
strings::StrAppend(&comment, " ", output_names[i]);
|
||||
if (op_def.output_arg(i).description().empty()) {
|
||||
if (api_def.out_arg(i).description().empty()) {
|
||||
strings::StrAppend(&comment, "\n");
|
||||
} else {
|
||||
// TODO(josh11b): Word wrap this.
|
||||
strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
|
||||
strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
|
||||
"\n");
|
||||
}
|
||||
}
|
||||
@ -564,19 +587,20 @@ string OpInfo::GetOpAttrStruct() const {
|
||||
string struct_fields;
|
||||
string setters;
|
||||
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
|
||||
const auto& attr(graph_op_def.attr(i));
|
||||
const auto& api_def_attr(api_def.attr(i));
|
||||
// If attr will be inferred or it doesn't have a default value, don't
|
||||
// add it to the struct.
|
||||
if ((inferred_input_attrs.find(attr.name()) !=
|
||||
inferred_input_attrs.end()) ||
|
||||
!attr.has_default_value()) {
|
||||
!api_def_attr.has_default_value()) {
|
||||
continue;
|
||||
}
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
const string camel_case_name = ToCamelCase(attr.name());
|
||||
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
|
||||
const string suffix =
|
||||
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
|
||||
const string attr_func_def =
|
||||
@ -584,22 +608,25 @@ string OpInfo::GetOpAttrStruct() const {
|
||||
attr_type_name, use_const ? "&" : "");
|
||||
|
||||
string attr_comment;
|
||||
if (!attr.description().empty()) {
|
||||
strings::StrAppend(&attr_comment, attr.description(), "\n\n");
|
||||
if (!api_def_attr.description().empty()) {
|
||||
strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
|
||||
}
|
||||
strings::StrAppend(&attr_comment, "Defaults to ",
|
||||
SummarizeAttrValue(attr.default_value()), "\n");
|
||||
SummarizeAttrValue(api_def_attr.default_value()), "\n");
|
||||
attr_comment = MakeComment(attr_comment, " ");
|
||||
|
||||
strings::StrAppend(&setters, attr_comment);
|
||||
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
|
||||
strings::StrAppend(&setters, " Attrs ret = *this;\n");
|
||||
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
|
||||
strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
|
||||
"_ = x;\n");
|
||||
strings::StrAppend(&setters, " return ret;\n }\n\n");
|
||||
|
||||
strings::StrAppend(
|
||||
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
|
||||
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
|
||||
&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
|
||||
"_ = ",
|
||||
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
|
||||
";\n");
|
||||
}
|
||||
|
||||
if (struct_fields.empty()) {
|
||||
@ -676,17 +703,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
|
||||
// Add the static functions to set optional attrs
|
||||
if (has_optional_attrs) {
|
||||
strings::StrAppend(&class_decl, "\n");
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
|
||||
const auto& attr(graph_op_def.attr(i));
|
||||
const auto& api_def_attr(api_def.attr(i));
|
||||
if ((inferred_input_attrs.find(attr.name()) !=
|
||||
inferred_input_attrs.end()) ||
|
||||
!attr.has_default_value()) {
|
||||
!api_def_attr.has_default_value()) {
|
||||
continue;
|
||||
}
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
const string camel_case_name = ToCamelCase(attr.name());
|
||||
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
|
||||
const string suffix =
|
||||
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
|
||||
const string attr_func_def = strings::StrCat(
|
||||
@ -726,11 +754,11 @@ void OpInfo::GetOutput(string* out) const {
|
||||
strings::StrCat("if (!", scope_str, ".ok()) return;");
|
||||
|
||||
// No outputs.
|
||||
if (op_def.output_arg_size() == 0) {
|
||||
if (graph_op_def.output_arg_size() == 0) {
|
||||
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
|
||||
return;
|
||||
}
|
||||
if (op_def.output_arg_size() == 1) {
|
||||
if (graph_op_def.output_arg_size() == 1) {
|
||||
// One output, no need for NameRangeMap
|
||||
if (is_list_output[0]) {
|
||||
strings::StrAppend(out,
|
||||
@ -752,7 +780,7 @@ void OpInfo::GetOutput(string* out) const {
|
||||
".UpdateStatus(_status_);\n", " return;\n");
|
||||
strings::StrAppend(out, " }\n\n");
|
||||
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
|
||||
const string arg_range = strings::StrCat(
|
||||
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
|
||||
if (is_list_output[i]) {
|
||||
@ -776,11 +804,13 @@ string OpInfo::GetConstructorBody() const {
|
||||
|
||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
|
||||
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
|
||||
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
|
||||
for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(graph_op_def.input_arg(i));
|
||||
const auto& api_def_arg(api_def.in_arg(i));
|
||||
strings::StrAppend(
|
||||
&body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
|
||||
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
|
||||
AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
|
||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||
}
|
||||
|
||||
@ -791,19 +821,21 @@ string OpInfo::GetConstructorBody() const {
|
||||
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
|
||||
graph_op_def.name(), "\")\n");
|
||||
const string spaces = " ";
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
|
||||
for (int i = 0; i < api_def.in_arg_size(); ++i) {
|
||||
const auto& arg(api_def.in_arg(i));
|
||||
strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
|
||||
}
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
for (int i = 0; i < api_def.attr_size(); ++i) {
|
||||
const auto& graph_attr(graph_op_def.attr(i));
|
||||
const auto& attr(op_def.attr(i));
|
||||
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
|
||||
const auto& api_def_attr(api_def.attr(i));
|
||||
if (inferred_input_attrs.find(api_def_attr.name()) !=
|
||||
inferred_input_attrs.end()) {
|
||||
continue;
|
||||
}
|
||||
const string attr_name = attr.has_default_value()
|
||||
? strings::StrCat("attrs.", attr.name(), "_")
|
||||
: AvoidCPPKeywords(attr.name());
|
||||
const string attr_name =
|
||||
api_def_attr.has_default_value()
|
||||
? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
|
||||
: AvoidCPPKeywords(api_def_attr.rename_to());
|
||||
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
|
||||
attr_name, ")\n");
|
||||
}
|
||||
@ -845,10 +877,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
|
||||
TF_CHECK_OK(cc->Append(class_def));
|
||||
}
|
||||
|
||||
void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
|
||||
void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
|
||||
const std::vector<string>& aliases, WritableFile* h,
|
||||
WritableFile* cc) {
|
||||
OpInfo op_info(graph_op_def, interface_op_def, aliases);
|
||||
OpInfo op_info(graph_op_def, api_def, aliases);
|
||||
|
||||
op_info.WriteClassDecl(h);
|
||||
op_info.WriteClassDef(cc);
|
||||
@ -943,8 +975,9 @@ string MakeInternal(const string& fname) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
|
||||
const string& dot_cc_fname, const string& overrides_fnames) {
|
||||
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& dot_h_fname, const string& dot_cc_fname,
|
||||
const string& overrides_fnames) {
|
||||
Env* env = Env::Default();
|
||||
|
||||
// Load the override map.
|
||||
@ -984,24 +1017,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
|
||||
// code depends on it.
|
||||
if (graph_op_def.name() == "Const") continue;
|
||||
|
||||
// Incorporate overrides from override_map.
|
||||
OpDef interface_op_def = graph_op_def;
|
||||
const OpGenOverride* op_override =
|
||||
override_map.ApplyOverride(&interface_op_def);
|
||||
std::vector<string> aliases;
|
||||
if (op_override) {
|
||||
if (op_override->skip()) continue;
|
||||
aliases.assign(op_override->alias().begin(), op_override->alias().end());
|
||||
if (op_override->hide()) {
|
||||
// Write hidden ops to _internal.h and _internal.cc.
|
||||
WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
|
||||
internal_cc.get());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
|
||||
|
||||
std::vector<string> aliases;
|
||||
if (api_def->visibility() == ApiDef::SKIP) continue;
|
||||
// First endpoint is canonical, the rest are aliases.
|
||||
for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
|
||||
++endpoint_i) {
|
||||
aliases.push_back(api_def->endpoint(endpoint_i).name());
|
||||
}
|
||||
if (api_def->visibility() == ApiDef::HIDDEN) {
|
||||
// Write hidden ops to _internal.h and _internal.cc.
|
||||
WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
|
||||
internal_cc.get());
|
||||
continue;
|
||||
}
|
||||
// This isn't a hidden op, write it to the main files.
|
||||
WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
|
||||
WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
|
||||
}
|
||||
|
||||
FinishFiles(false, h.get(), cc.get(), op_header_guard);
|
||||
|
@ -17,13 +17,15 @@ limitations under the License.
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Result is written to files dot_h and dot_cc.
|
||||
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
|
||||
const string& dot_cc_fname, const string& overrides_fnames);
|
||||
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
|
||||
const string& dot_h_fname, const string& dot_cc_fname,
|
||||
const string& overrides_fnames);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -16,7 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -24,10 +28,28 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
|
||||
const std::string& overrides_fnames, bool include_internal) {
|
||||
const std::string& overrides_fnames, bool include_internal,
|
||||
const std::vector<string>& api_def_dirs) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(include_internal, &ops);
|
||||
WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
|
||||
ApiDefMap api_def_map(ops);
|
||||
if (!api_def_dirs.empty()) {
|
||||
Env* env = Env::Default();
|
||||
// Only load files that correspond to "ops".
|
||||
for (const auto& op : ops.op()) {
|
||||
for (const auto& api_def_dir : api_def_dirs) {
|
||||
const std::string api_def_file_pattern =
|
||||
io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
|
||||
if (env->FileExists(api_def_file_pattern).ok()) {
|
||||
TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
api_def_map.UpdateDocs();
|
||||
|
||||
WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
if (argc != 5) {
|
||||
// TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt
|
||||
// as an argument.
|
||||
if (argc != 6) {
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
|
||||
}
|
||||
fprintf(stderr,
|
||||
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
|
||||
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal "
|
||||
"api_def_dirs1,api_def_dir2 ...\n"
|
||||
" include_internal: 1 means include internal ops\n",
|
||||
argv[0]);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool include_internal = tensorflow::StringPiece("1") == argv[4];
|
||||
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
|
||||
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
|
||||
argv[5], ",", tensorflow::str_util::SkipEmpty());
|
||||
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal,
|
||||
api_def_dirs);
|
||||
return 0;
|
||||
}
|
||||
|
195
tensorflow/cc/framework/cc_op_gen_test.cc
Normal file
195
tensorflow/cc/framework/cc_op_gen_test.cc
Normal file
@ -0,0 +1,195 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// TODO(annarev): Remove this op_gen_overrides.pbtxt reference.
|
||||
// It is needed only because WriteCCOps takes it as an argument.
|
||||
constexpr char kOverridesFnames[] =
|
||||
"tensorflow/cc/ops/op_gen_overrides.pbtxt";
|
||||
constexpr char kBaseOpDef[] = R"(
|
||||
op {
|
||||
name: "Foo"
|
||||
input_arg {
|
||||
name: "images"
|
||||
description: "Images to process."
|
||||
}
|
||||
input_arg {
|
||||
name: "dim"
|
||||
description: "Description for dim."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Description for output."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
description: "Type for images"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
}
|
||||
}
|
||||
default_value {
|
||||
i: 1
|
||||
}
|
||||
}
|
||||
summary: "Summary for op Foo."
|
||||
description: "Description for op Foo."
|
||||
}
|
||||
)";
|
||||
|
||||
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_TRUE(s.contains(expected))
|
||||
<< "'" << s << "' does not contain '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
|
||||
EXPECT_FALSE(s.contains(expected))
|
||||
<< "'" << s << "' contains '" << expected << "'";
|
||||
}
|
||||
|
||||
void ExpectSubstrOrder(const string& s, const string& before,
|
||||
const string& after) {
|
||||
int before_pos = s.find(before);
|
||||
int after_pos = s.find(after);
|
||||
ASSERT_NE(std::string::npos, before_pos);
|
||||
ASSERT_NE(std::string::npos, after_pos);
|
||||
EXPECT_LT(before_pos, after_pos)
|
||||
<< before << " is not before " << after << " in " << s;
|
||||
}
|
||||
|
||||
// Runs WriteCCOps and stores output in (internal_)cc_file_path and
|
||||
// (internal_)h_file_path.
|
||||
void GenerateCcOpFiles(Env* env, const OpList& ops,
|
||||
const ApiDefMap& api_def_map, string* h_file_text,
|
||||
string* internal_h_file_text) {
|
||||
const string& tmpdir = testing::TmpDir();
|
||||
|
||||
const auto h_file_path = io::JoinPath(tmpdir, "test.h");
|
||||
const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
|
||||
const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
|
||||
const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
|
||||
|
||||
WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames);
|
||||
|
||||
TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
|
||||
TF_ASSERT_OK(
|
||||
ReadFileToString(env, internal_h_file_path, internal_h_file_text));
|
||||
}
|
||||
|
||||
TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
|
||||
const string api_def = R"(
|
||||
op {
|
||||
graph_op_name: "Foo"
|
||||
visibility: HIDDEN
|
||||
}
|
||||
)";
|
||||
Env* env = Env::Default();
|
||||
OpList op_defs;
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
|
||||
string h_file_text, internal_h_file_text;
|
||||
// Without ApiDef
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectHasSubstr(h_file_text, "class Foo");
|
||||
ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");
|
||||
|
||||
// With ApiDef
|
||||
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectHasSubstr(internal_h_file_text, "class Foo");
|
||||
ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
|
||||
}
|
||||
|
||||
TEST(CcOpGenTest, TestArgNameChanges) {
|
||||
const string api_def = R"(
|
||||
op {
|
||||
graph_op_name: "Foo"
|
||||
arg_order: "dim"
|
||||
arg_order: "images"
|
||||
}
|
||||
)";
|
||||
Env* env = Env::Default();
|
||||
OpList op_defs;
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
|
||||
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
string cc_file_text, h_file_text;
|
||||
string internal_cc_file_text, internal_h_file_text;
|
||||
// Without ApiDef
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectSubstrOrder(h_file_text, "Input images", "Input dim");
|
||||
|
||||
// With ApiDef
|
||||
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
|
||||
}
|
||||
|
||||
TEST(CcOpGenTest, TestEndpoints) {
|
||||
const string api_def = R"(
|
||||
op {
|
||||
graph_op_name: "Foo"
|
||||
endpoint {
|
||||
name: "Foo1"
|
||||
}
|
||||
endpoint {
|
||||
name: "Foo2"
|
||||
}
|
||||
}
|
||||
)";
|
||||
Env* env = Env::Default();
|
||||
OpList op_defs;
|
||||
protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
|
||||
|
||||
ApiDefMap api_def_map(op_defs);
|
||||
string cc_file_text, h_file_text;
|
||||
string internal_cc_file_text, internal_h_file_text;
|
||||
// Without ApiDef
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectHasSubstr(h_file_text, "class Foo {");
|
||||
ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
|
||||
ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");
|
||||
|
||||
// With ApiDef
|
||||
TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
|
||||
GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
|
||||
&internal_h_file_text);
|
||||
ExpectHasSubstr(h_file_text, "class Foo1");
|
||||
ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
|
||||
ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
|
||||
${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc
|
||||
${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h
|
||||
${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc
|
||||
COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal}
|
||||
COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api
|
||||
DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir
|
||||
)
|
||||
|
||||
|
@ -3371,7 +3371,7 @@ tf_cc_test(
|
||||
|
||||
filegroup(
|
||||
name = "base_api_def",
|
||||
data = glob(["api_def/base_api/*"]),
|
||||
srcs = glob(["api_def/base_api/*"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
@ -3386,10 +3386,6 @@ tf_cc_test(
|
||||
":base_api_def",
|
||||
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":framework",
|
||||
":framework_internal",
|
||||
|
@ -221,9 +221,18 @@ std::unordered_map<string, ApiDefs> GenerateApiDef(
|
||||
|
||||
std::unordered_map<string, ApiDefs> api_defs_map;
|
||||
|
||||
// These ops are included in OpList only if TF_NEED_GCP
|
||||
// is set to true. So, we skip them for now so that this test passes
|
||||
// whether TF_NEED_GCP is set or not.
|
||||
const std::unordered_set<string> ops_to_exclude = {
|
||||
"BigQueryReader", "GenerateBigQueryReaderPartitions"};
|
||||
for (const auto& op : ops.op()) {
|
||||
CHECK(!op.name().empty())
|
||||
<< "Encountered empty op name: %s" << op.DebugString();
|
||||
if (ops_to_exclude.find(op.name()) != ops_to_exclude.end()) {
|
||||
LOG(INFO) << "Skipping " << op.name();
|
||||
continue;
|
||||
}
|
||||
string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat);
|
||||
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
|
||||
ApiDef* api_def = api_defs_map[file_path].add_op();
|
||||
|
65
tensorflow/core/api_def/base_api/api_def_ApplyAddSign.pbtxt
Normal file
65
tensorflow/core/api_def/base_api/api_def_ApplyAddSign.pbtxt
Normal file
@ -0,0 +1,65 @@
|
||||
op {
|
||||
graph_op_name: "ApplyAddSign"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "m"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alpha"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sign_decay"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and m tensors is
|
||||
protected by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the AddSign update."
|
||||
description: <<END
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
update <- (alpha + sign_decay * sign(g) *sign(m)) * g
|
||||
variable <- variable - lr_t * update
|
||||
END
|
||||
}
|
@ -0,0 +1,65 @@
|
||||
op {
|
||||
graph_op_name: "ApplyPowerSign"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "m"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "logbase"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sign_decay"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "out"
|
||||
description: <<END
|
||||
Same as "var".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and m tensors is
|
||||
protected by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the AddSign update."
|
||||
description: <<END
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
|
||||
variable <- variable - lr_t * update
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "BytesProducedStatsDataset"
|
||||
summary: "Records the bytes size of each element of `input_dataset` in a StatsAggregator."
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
op {
|
||||
graph_op_name: "DeserializeSparse"
|
||||
in_arg {
|
||||
name: "serialized_sparse"
|
||||
description: <<END
|
||||
1-D, The serialized `SparseTensor` object. Must have 3 columns.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The `dtype` of the serialized `SparseTensor` object.
|
||||
END
|
||||
}
|
||||
summary: "Deserialize `SparseTensor` from a (serialized) string 3-vector (1-D `Tensor`)"
|
||||
description: <<END
|
||||
object.
|
||||
END
|
||||
}
|
@ -36,6 +36,13 @@ END
|
||||
name: "num_new_vocab"
|
||||
description: <<END
|
||||
Number of entries in the new vocab file to remap.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "old_vocab_size"
|
||||
description: <<END
|
||||
Number of entries in the old vocab file to consider. If -1,
|
||||
use the entire old vocabulary.
|
||||
END
|
||||
}
|
||||
summary: "Given a path to new and old vocabulary files, returns a remapping Tensor of"
|
||||
@ -43,7 +50,11 @@ END
|
||||
length `num_new_vocab`, where `remapping[i]` contains the row number in the old
|
||||
vocabulary that corresponds to row `i` in the new vocabulary (starting at line
|
||||
`new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
|
||||
in the new vocabulary is not in the old vocabulary. `num_vocab_offset` enables
|
||||
in the new vocabulary is not in the old vocabulary. The old vocabulary is
|
||||
constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
|
||||
default value of -1.
|
||||
|
||||
`num_vocab_offset` enables
|
||||
use in the partitioned variable case, and should generally be set through
|
||||
examining partitioning info. The format of the files should be a text file,
|
||||
with each line containing a single entity within the vocabulary.
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "IteratorSetStatsAggregator"
|
||||
summary: "Associates the given iterator with the given statistics aggregator."
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "LatencyStatsDataset"
|
||||
summary: "Records the latency of producing `input_dataset` elements in a StatsAggregator."
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
op {
|
||||
graph_op_name: "MatrixExponential"
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
Shape is `[..., M, M]`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Shape is `[..., M, M]`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.linalg.expm
|
||||
@end_compatibility
|
||||
END
|
||||
}
|
||||
summary: "Computes the matrix exponential of one or more square matrices:"
|
||||
description: <<END
|
||||
exp(A) = \sum_{n=0}^\infty A^n/n!
|
||||
|
||||
The exponential is computed using a combination of the scaling and squaring
|
||||
method and the Pade approximation. Details can be founds in:
|
||||
Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
|
||||
revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
|
||||
|
||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||
form square matrices. The output is a tensor of the same shape as the input
|
||||
containing the exponential for all input submatrices `[..., :, :]`.
|
||||
END
|
||||
}
|
@ -26,7 +26,7 @@ When set to True, find the nth-largest value in the vector and vice
|
||||
versa.
|
||||
END
|
||||
}
|
||||
summary: "Finds values of the `n`-th order statistic for the last dmension."
|
||||
summary: "Finds values of the `n`-th order statistic for the last dimension."
|
||||
description: <<END
|
||||
If the input is a vector (rank-1), finds the entries which is the nth-smallest
|
||||
value in the vector and outputs their values as scalar tensor.
|
||||
|
@ -0,0 +1,59 @@
|
||||
op {
|
||||
graph_op_name: "ResourceApplyAddSign"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "m"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alpha"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sign_decay"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and m tensors is
|
||||
protected by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the AddSign update."
|
||||
description: <<END
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
update <- (alpha + sign_decay * sign(g) *sign(m)) * g
|
||||
variable <- variable - lr_t * update
|
||||
END
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
op {
|
||||
graph_op_name: "ResourceApplyPowerSign"
|
||||
in_arg {
|
||||
name: "var"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "m"
|
||||
description: <<END
|
||||
Should be from a Variable().
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "lr"
|
||||
description: <<END
|
||||
Scaling factor. Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "logbase"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "sign_decay"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "beta"
|
||||
description: <<END
|
||||
Must be a scalar.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "grad"
|
||||
description: <<END
|
||||
The gradient.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_locking"
|
||||
description: <<END
|
||||
If `True`, updating of the var and m tensors is
|
||||
protected by a lock; otherwise the behavior is undefined, but may exhibit less
|
||||
contention.
|
||||
END
|
||||
}
|
||||
summary: "Update \'*var\' according to the AddSign update."
|
||||
description: <<END
|
||||
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
|
||||
update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
|
||||
variable <- variable - lr_t * update
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorHandle"
|
||||
summary: "Creates a statistics manager resource."
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "StatsAggregatorSummary"
|
||||
summary: "Produces a summary of any statistics recorded by the given statistics manager."
|
||||
}
|
@ -48,6 +48,17 @@ END
|
||||
If true (default), Tensors in the TensorArray are cleared
|
||||
after being read. This disables multiple read semantics but allows early
|
||||
release of memory.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "identical_element_shapes"
|
||||
description: <<END
|
||||
If true (default is false), then all
|
||||
elements in the TensorArray will be expected to have have identical shapes.
|
||||
This allows certain behaviors, like dynamically checking for
|
||||
consistent shapes on write, and being able to fill in properly
|
||||
shaped zero tensors on stack -- even if the element_shape attribute
|
||||
is not fully defined.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "DeserializeSparse"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "MatrixExponential"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -281,6 +281,9 @@ static void StringReplace(const string& from, const string& to, string* s) {
|
||||
} else {
|
||||
split.push_back(s->substr(pos, found - pos));
|
||||
pos = found + from.size();
|
||||
if (pos == s->size()) { // handle case where `from` is at the very end.
|
||||
split.push_back("");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Join the pieces back together with a new delimiter.
|
||||
@ -316,6 +319,36 @@ static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
|
||||
}
|
||||
}
|
||||
|
||||
static void RenameInDocs(const string& from, const string& to,
|
||||
ApiDef* api_def) {
|
||||
const string from_quoted = strings::StrCat("`", from, "`");
|
||||
const string to_quoted = strings::StrCat("`", to, "`");
|
||||
for (int i = 0; i < api_def->in_arg_size(); ++i) {
|
||||
if (!api_def->in_arg(i).description().empty()) {
|
||||
StringReplace(from_quoted, to_quoted,
|
||||
api_def->mutable_in_arg(i)->mutable_description());
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < api_def->out_arg_size(); ++i) {
|
||||
if (!api_def->out_arg(i).description().empty()) {
|
||||
StringReplace(from_quoted, to_quoted,
|
||||
api_def->mutable_out_arg(i)->mutable_description());
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < api_def->attr_size(); ++i) {
|
||||
if (!api_def->attr(i).description().empty()) {
|
||||
StringReplace(from_quoted, to_quoted,
|
||||
api_def->mutable_attr(i)->mutable_description());
|
||||
}
|
||||
}
|
||||
if (!api_def->summary().empty()) {
|
||||
StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
|
||||
}
|
||||
if (!api_def->description().empty()) {
|
||||
StringReplace(from_quoted, to_quoted, api_def->mutable_description());
|
||||
}
|
||||
}
|
||||
|
||||
const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
|
||||
// Look up
|
||||
const auto iter = map_.find(op_def->name());
|
||||
@ -521,6 +554,7 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
|
||||
". All elements in arg_order override must match base arg_order: ",
|
||||
str_util::Join(base_api_def->arg_order(), ", "));
|
||||
}
|
||||
|
||||
base_api_def->clear_arg_order();
|
||||
std::copy(
|
||||
new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
|
||||
@ -608,6 +642,32 @@ Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ApiDefMap::UpdateDocs() {
|
||||
for (auto& name_and_api_def : map_) {
|
||||
auto& api_def = name_and_api_def.second;
|
||||
CHECK_GT(api_def.endpoint_size(), 0);
|
||||
const string canonical_name = api_def.endpoint(0).name();
|
||||
if (api_def.graph_op_name() != canonical_name) {
|
||||
RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
|
||||
}
|
||||
for (const auto& in_arg : api_def.in_arg()) {
|
||||
if (in_arg.name() != in_arg.rename_to()) {
|
||||
RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
|
||||
}
|
||||
}
|
||||
for (const auto& out_arg : api_def.out_arg()) {
|
||||
if (out_arg.name() != out_arg.rename_to()) {
|
||||
RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
|
||||
}
|
||||
}
|
||||
for (const auto& attr : api_def.attr()) {
|
||||
if (attr.name() != attr.rename_to()) {
|
||||
RenameInDocs(attr.name(), attr.rename_to(), &api_def);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
|
||||
return gtl::FindOrNull(map_, name);
|
||||
}
|
||||
|
@ -106,6 +106,12 @@ class ApiDefMap {
|
||||
// passed to the constructor.
|
||||
Status LoadApiDef(const string& api_def_file_contents);
|
||||
|
||||
// Updates ApiDef docs. For example, if ApiDef renames an argument
|
||||
// or attribute, applies these renames to descriptions as well.
|
||||
// UpdateDocs should only be called once after all ApiDefs are loaded
|
||||
// since it replaces original op names.
|
||||
void UpdateDocs();
|
||||
|
||||
// Look up ApiDef proto based on the given graph op name.
|
||||
// If graph op name is not in this ApiDefMap, returns nullptr.
|
||||
//
|
||||
|
@ -455,5 +455,62 @@ op {
|
||||
status = api_map.LoadApiDef(api_def3);
|
||||
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
|
||||
}
|
||||
|
||||
TEST(OpGenLibTest, ApiDefUpdateDocs) {
|
||||
const string op_list1 = R"(op {
|
||||
name: "testop"
|
||||
input_arg {
|
||||
name: "arg_a"
|
||||
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
|
||||
}
|
||||
output_arg {
|
||||
name: "arg_c"
|
||||
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
|
||||
}
|
||||
attr {
|
||||
name: "attr_a"
|
||||
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
|
||||
}
|
||||
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
|
||||
}
|
||||
)";
|
||||
|
||||
const string api_def1 = R"(
|
||||
op {
|
||||
graph_op_name: "testop"
|
||||
endpoint {
|
||||
name: "testop2"
|
||||
}
|
||||
in_arg {
|
||||
name: "arg_a"
|
||||
rename_to: "arg_aa"
|
||||
}
|
||||
out_arg {
|
||||
name: "arg_c"
|
||||
rename_to: "arg_cc"
|
||||
description: "New description: `arg_a`, `arg_c`, `attr_a`, `testop`"
|
||||
}
|
||||
attr {
|
||||
name: "attr_a"
|
||||
rename_to: "attr_aa"
|
||||
}
|
||||
}
|
||||
)";
|
||||
OpList op_list;
|
||||
protobuf::TextFormat::ParseFromString(op_list1, &op_list); // NOLINT
|
||||
ApiDefMap api_map(op_list);
|
||||
TF_CHECK_OK(api_map.LoadApiDef(api_def1));
|
||||
api_map.UpdateDocs();
|
||||
|
||||
const string expected_description =
|
||||
"`arg_aa`, `arg_cc`, `attr_aa`, `testop2`";
|
||||
EXPECT_EQ(expected_description, api_map.GetApiDef("testop")->description());
|
||||
EXPECT_EQ(expected_description,
|
||||
api_map.GetApiDef("testop")->in_arg(0).description());
|
||||
EXPECT_EQ("New description: " + expected_description,
|
||||
api_map.GetApiDef("testop")->out_arg(0).description());
|
||||
EXPECT_EQ(expected_description,
|
||||
api_map.GetApiDef("testop")->attr(0).description());
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -316,7 +316,9 @@ def tf_gen_op_wrapper_cc(name,
|
||||
op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
|
||||
deps=None,
|
||||
override_file=None,
|
||||
include_internal_ops=0):
|
||||
include_internal_ops=0,
|
||||
# ApiDefs will be loaded in the order specified in this list.
|
||||
api_def_srcs=[]):
|
||||
# Construct an op generator binary for these ops.
|
||||
tool = out_ops_file + "_gen_cc"
|
||||
if deps == None:
|
||||
@ -328,12 +330,26 @@ def tf_gen_op_wrapper_cc(name,
|
||||
linkstatic=1, # Faster to link this one-time-use binary dynamically
|
||||
deps=[op_gen] + deps)
|
||||
|
||||
srcs = api_def_srcs[:]
|
||||
|
||||
if override_file == None:
|
||||
srcs = []
|
||||
override_arg = ","
|
||||
else:
|
||||
srcs = [override_file]
|
||||
srcs += [override_file]
|
||||
override_arg = "$(location " + override_file + ")"
|
||||
|
||||
if not api_def_srcs:
|
||||
api_def_args_str = ","
|
||||
else:
|
||||
api_def_args = []
|
||||
for api_def_src in api_def_srcs:
|
||||
# Add directory of the first ApiDef source to args.
|
||||
# We are assuming all ApiDefs in a single api_def_src are in the
|
||||
# same directory.
|
||||
api_def_args.append(
|
||||
" $$(dirname $$(echo $(locations " + api_def_src +
|
||||
") | cut -d\" \" -f1))")
|
||||
api_def_args_str = ",".join(api_def_args)
|
||||
native.genrule(
|
||||
name=name + "_genrule",
|
||||
outs=[
|
||||
@ -344,7 +360,7 @@ def tf_gen_op_wrapper_cc(name,
|
||||
tools=[":" + tool] + tf_binary_additional_srcs(),
|
||||
cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
|
||||
"$(location :" + out_ops_file + ".cc) " + override_arg + " " +
|
||||
str(include_internal_ops)))
|
||||
str(include_internal_ops) + " " + api_def_args_str))
|
||||
|
||||
|
||||
# Given a list of "op_lib_names" (a list of files in the ops directory
|
||||
@ -387,7 +403,9 @@ def tf_gen_op_wrappers_cc(name,
|
||||
op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
|
||||
override_file=None,
|
||||
include_internal_ops=0,
|
||||
visibility=None):
|
||||
visibility=None,
|
||||
# ApiDefs will be loaded in the order apecified in this list.
|
||||
api_def_srcs=[]):
|
||||
subsrcs = other_srcs[:]
|
||||
subhdrs = other_hdrs[:]
|
||||
internalsrcs = []
|
||||
@ -399,7 +417,8 @@ def tf_gen_op_wrappers_cc(name,
|
||||
pkg=pkg,
|
||||
op_gen=op_gen,
|
||||
override_file=override_file,
|
||||
include_internal_ops=include_internal_ops)
|
||||
include_internal_ops=include_internal_ops,
|
||||
api_def_srcs=api_def_srcs)
|
||||
subsrcs += ["ops/" + n + ".cc"]
|
||||
subhdrs += ["ops/" + n + ".h"]
|
||||
internalsrcs += ["ops/" + n + "_internal.cc"]
|
||||
|
Loading…
Reference in New Issue
Block a user