Add a mechanism for hiding, skipping, and modifying the generated op
functions for C++.

A souped-up version of the hidden_ops mechanism in Python: the intent is
to use this for most or all of the client languages, with a common list
of changes to make in a common file plus per-language overrides.

Also:
* include the documentation for outputs in the generated comments
* several updates to the C++ API to match Python
* fix the C++ shape function for ConcatV2, now that we use it by default
* split op_gen_lib out of core:framework, since it is only used by the
  op generators, and I don't want to add another proto to mobile builds

Change: 146267344
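Note: the most visible API change from these overrides is argument order. With
the `op { name: "ConcatV2" rename_to: "Concat" }` entry below, the generated
`ops::Concat` wraps the ConcatV2 graph op, whose axis comes after the values.
A sketch of a call site before and after (`root`, `a`, and `b` are hypothetical):

    auto c_old = ops::Concat(root, 0, {a, b});   // old: wrapped "Concat", dim first
    auto c_new = ops::Concat(root, {a, b}, 0);   // new: wraps "ConcatV2", axis last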
parent 287e845c52
commit 8fe32029f7

Changed paths: tensorflow/cc, tensorflow/contrib/makefile, tensorflow/core,
tensorflow/examples, tensorflow/python, tensorflow/tensorflow.bzl,
tensorflow/tools/graph_transforms
@@ -228,6 +228,7 @@ cc_library(
     srcs = ["gradients/array_grad.cc"],
     deps = [
         ":cc_ops",
+        ":cc_ops_internal",
         ":grad_op_registry",
         ":gradients",
     ],
@@ -239,6 +240,7 @@ tf_cc_test(
     deps = [
         ":array_grad",
         ":cc_ops",
+        ":cc_ops_internal",
         ":grad_op_registry",
         ":grad_testutil",
         ":gradient_checker",
@@ -334,6 +336,7 @@ tf_gen_op_wrappers_cc(
         "ops/const_op.h",
         "ops/standard_ops.h",
     ],
+    override_file = "ops/op_gen_overrides.pbtxt",
     pkg = "//tensorflow/core",
 )

@@ -387,6 +390,7 @@ cc_library_with_android_deps(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:op_gen_lib",
         "//tensorflow/core:proto_text",
         "//tensorflow/core:protos_all_cc",
     ],
@@ -15,6 +15,7 @@ limitations under the License.

+#include <unordered_map>
 #include <unordered_set>
 #include <vector>

 #include "tensorflow/cc/framework/cc_op_gen.h"
 #include "tensorflow/core/framework/attr_value_util.h"
@@ -38,7 +39,7 @@ const int kRightMargin = 79;
 // Converts:
 //   bazel-out/.../genfiles/(external/YYY/)?XX
 // to: XX.
-string GetPath(const std::string& dot_h_fname) {
+string GetPath(const string& dot_h_fname) {
   auto pos = dot_h_fname.find("/genfiles/");
   string result = dot_h_fname;
   if (pos != string::npos) {
@@ -60,7 +61,7 @@ string GetPath(const string& dot_h_fname) {
 //   cc/ops/gen_foo_ops.h
 // to:
 //   CC_OPS_GEN_FOO_OPS_H_
-string ToGuard(const std::string& path) {
+string ToGuard(const string& path) {
   string guard;
   guard.reserve(path.size() + 1);  // + 1 -> trailing _
   for (const char c : path) {
@@ -360,7 +361,12 @@ bool HasOptionalAttrs(
 }

 struct OpInfo {
-  explicit OpInfo(const OpDef& op_def);
+  // graph_op_def: The OpDef used by the runtime, has the names that
+  //   must be used when calling NodeBuilder.
+  // interface_op_def: The OpDef used in the interface in the generated
+  //   code, with possibly overridden names and defaults.
+  explicit OpInfo(const OpDef& graph_op_def, const OpDef& interface_op_def,
+                  const std::vector<string>& aliases);
   string GetOpAttrStruct() const;
   string GetConstructorDecl(StringPiece op_name_prefix,
                             bool include_attr) const;
@@ -378,11 +384,15 @@ struct OpInfo {
   bool has_optional_attrs;
   string comment;

+  const OpDef& graph_op_def;
   const OpDef& op_def;
+  const std::vector<string>& aliases;
   std::unordered_map<string, string> inferred_input_attrs;
 };

-OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
+OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
+               const std::vector<string>& a)
+    : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
+  op_name = op_def.name();
   InferOpAttributes(op_def, &inferred_input_attrs);
   has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
@@ -443,7 +453,6 @@ OpInfo::OpInfo(const OpDef& op_def) {
       strings::StrAppend(&comment, " ", attr.description(), "\n");
     }
   }
-  comment = MakeComment(comment, "");

   for (int i = 0; i < op_def.output_arg_size(); ++i) {
     const auto& arg = op_def.output_arg(i);
@@ -453,13 +462,55 @@ OpInfo::OpInfo(const OpDef& op_def) {
     output_names.push_back(AvoidCPPKeywords(arg.name()));
     is_list_output.push_back(is_list);
   }

+  strings::StrAppend(&comment, "\nReturns:\n");
+  if (op_def.output_arg_size() == 0) {  // No outputs.
+    strings::StrAppend(&comment, "* the created `Operation`\n");
+  } else if (op_def.output_arg_size() == 1) {  // One output
+    if (is_list_output[0]) {
+      strings::StrAppend(&comment, "* `OutputList`: ");
+    } else {
+      strings::StrAppend(&comment, "* `Output`: ");
+    }
+    if (op_def.output_arg(0).description().empty()) {
+      strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
+                         " tensor.\n");
+    } else {
+      // TODO(josh11b): Word wrap this.
+      strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
+    }
+  } else {  // Multiple outputs.
+    for (int i = 0; i < op_def.output_arg_size(); ++i) {
+      if (is_list_output[i]) {
+        strings::StrAppend(&comment, "* `OutputList`");
+      } else {
+        strings::StrAppend(&comment, "* `Output`");
+      }
+      strings::StrAppend(&comment, " ", output_names[i]);
+      if (op_def.output_arg(i).description().empty()) {
+        strings::StrAppend(&comment, "\n");
+      } else {
+        // TODO(josh11b): Word wrap this.
+        strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
+                           "\n");
+      }
+    }
+  }
+
+  if (!aliases.empty()) {
+    strings::StrAppend(&comment, "\nAliases:\n");
+    for (const auto& alias : aliases) {
+      strings::StrAppend(&comment, "* ", alias, "\n");
+    }
+  }
+  comment = MakeComment(comment, "");
 }

 string OpInfo::GetOpAttrStruct() const {
   string struct_fields;
   string setters;
-  string attrs_comment = strings::StrCat("Optional attribute setters for ",
-                                         op_def.name(), " :\n\n");
+  string attrs_comment =
+      strings::StrCat("Optional attribute setters for ", op_name, " :\n\n");

   for (int i = 0; i < op_def.attr_size(); ++i) {
     const auto& attr(op_def.attr(i));
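Note: for a hypothetical op Foo with a single non-list output named "out" and
no output description, the comment text assembled here ends with:

    Returns:
    * `Output`: The out tensor.

An op with no outputs instead gets "* the created `Operation`".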
@@ -603,7 +654,13 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
                        ";\n");
   }

-  strings::StrAppend(&class_decl, "};\n\n");
+  strings::StrAppend(&class_decl, "};\n");
+  if (!aliases.empty()) {
+    for (const auto& alias : aliases) {
+      strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n");
+    }
+  }
+  strings::StrAppend(&class_decl, "\n");
   TF_CHECK_OK(h->Append(class_decl));
 }
@@ -642,7 +699,7 @@ void OpInfo::GetOutput(string* out) const {

   for (int i = 0; i < op_def.output_arg_size(); ++i) {
     const string arg_range = strings::StrCat(
-        "_outputs_range[\"", op_def.output_arg(i).name(), "\"]");
+        "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
     if (is_list_output[i]) {
       strings::StrAppend(out, "  for (int64 i = ", arg_range, ".first; i < ",
                          arg_range, ".second; ++i)\n");
@@ -673,17 +730,18 @@ string OpInfo::GetConstructorBody() const {
   }

   strings::StrAppend(&body, "  ::tensorflow::Node* ret;\n");
-  strings::StrAppend(&body, "  const auto unique_name = ", scope_str,
-                     ".GetUniqueNameForOp(\"", op_def.name(), "\");\n");
+  strings::StrAppend(&body, "  const auto unique_name = ", scope_str,
+                     ".GetUniqueNameForOp(\"", op_name, "\");\n");
   strings::StrAppend(
       &body, "  auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
-      op_def.name(), "\")\n");
+      graph_op_def.name(), "\")\n");
   const string spaces = " ";
   for (int i = 0; i < op_def.input_arg_size(); ++i) {
     const auto& arg(op_def.input_arg(i));
     strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
   }
   for (int i = 0; i < op_def.attr_size(); ++i) {
+    const auto& graph_attr(graph_op_def.attr(i));
     const auto& attr(op_def.attr(i));
     if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
       continue;
@@ -691,7 +749,7 @@ string OpInfo::GetConstructorBody() const {
     const string attr_name = attr.has_default_value()
                                  ? strings::StrCat("attrs.", attr.name(), "_")
                                  : AvoidCPPKeywords(attr.name());
-    strings::StrAppend(&body, spaces, ".Attr(\"", attr.name(), "\", ",
+    strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
                        attr_name, ")\n");
   }
   strings::StrAppend(&body, "  ;\n");
@@ -736,23 +794,17 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
   TF_CHECK_OK(cc->Append(class_def));
 }

-void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
-  OpInfo op_info(op_def);
+void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
+               const std::vector<string>& aliases, WritableFile* h,
+               WritableFile* cc) {
+  OpInfo op_info(graph_op_def, interface_op_def, aliases);

   op_info.WriteClassDecl(h);
   op_info.WriteClassDef(cc);
 }

-}  // namespace
-
-void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
-                const std::string& dot_cc_fname) {
-  Env* env = Env::Default();
-  std::unique_ptr<WritableFile> h = nullptr;
-  std::unique_ptr<WritableFile> cc = nullptr;
-  TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
-  TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
-
+void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
+                WritableFile* cc, string* op_header_guard) {
   const string header =
       R"header(// This file is MACHINE GENERATED! Do not edit.
@@ -765,18 +817,22 @@ void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
 )header";

   // TODO(keveman): Make namespaces configurable.
-  const string namespace_begin = R"namespace(
+  const string namespace_begin = internal ? R"namespace(
 namespace tensorflow {
 namespace ops {
+namespace internal {
+// NOTE: This namespace has internal TensorFlow details that
+// are not part of TensorFlow's public API.
+
+)namespace"
+                                          : R"namespace(
+namespace tensorflow {
+namespace ops {

 )namespace";

-  const string footer = R"footer(}  // namespace ops
-}  // namespace tensorflow
-)footer";
-
   const string op_header = GetPath(dot_h_fname);
-  const string op_header_guard = ToGuard(op_header);
+  *op_header_guard = ToGuard(op_header);
   const string cc_header = strings::StrCat(
       R"include(// This file is MACHINE GENERATED! Do not edit.
@@ -788,27 +844,25 @@ namespace ops {
   TF_CHECK_OK(h->Append(
       strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
                       "#ifndef ",
-                      op_header_guard,
+                      *op_header_guard,
                       "\n"
                       "#define ",
-                      op_header_guard, "\n\n")));
+                      *op_header_guard, "\n\n")));
   TF_CHECK_OK(h->Append(header));
   TF_CHECK_OK(h->Append(namespace_begin));
   TF_CHECK_OK(cc->Append(cc_header));
+}

-  for (const auto& op_def : ops.op()) {
-    // Skip deprecated ops.
-    // TODO(josh11b): If needed, can put them into a "deprecated" namespace
-    // instead of skipping.
-    if (op_def.has_deprecation() &&
-        op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
-      continue;
-    }
-    // We use a hand-written wrapper for "Const", since the generated
-    // code depends on it.
-    if (op_def.name() == "Const") continue;
-    WriteCCOp(op_def, h.get(), cc.get());
-  }
+void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
+                 const string& op_header_guard) {
+  const string footer = internal ? R"footer(}  // namespace internal
+}  // namespace ops
+}  // namespace tensorflow
+)footer"
+                                 :
+R"footer(}  // namespace ops
+}  // namespace tensorflow
+)footer";

   TF_CHECK_OK(h->Append(footer));
   TF_CHECK_OK(
@@ -819,4 +873,82 @@ namespace ops {
   TF_CHECK_OK(h->Close());
 }

+string MakeInternal(const string& fname) {
+  auto dot_pos = fname.rfind('.');
+  if (dot_pos == string::npos) {
+    return strings::StrCat(fname, "_internal");
+  } else {
+    return strings::StrCat(fname.substr(0, dot_pos), "_internal",
+                           fname.substr(dot_pos));
+  }
+}
+
+}  // namespace
+
+void WriteCCOps(const OpList& ops, const string& dot_h_fname,
+                const string& dot_cc_fname, const string& overrides_fnames) {
+  Env* env = Env::Default();
+
+  // Load the override map.
+  OpGenOverrideMap override_map;
+  if (!overrides_fnames.empty()) {
+    override_map.LoadFileList(env, overrides_fnames);
+  }
+
+  // Write the initial boilerplate to the .h and .cc files.
+  std::unique_ptr<WritableFile> h = nullptr;
+  std::unique_ptr<WritableFile> cc = nullptr;
+  TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
+  TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
+  string op_header_guard;
+  StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
+
+  // Create the internal versions of these files for the hidden ops.
+  std::unique_ptr<WritableFile> internal_h = nullptr;
+  std::unique_ptr<WritableFile> internal_cc = nullptr;
+  const string internal_dot_h_fname = MakeInternal(dot_h_fname);
+  TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h));
+  TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
+  string internal_op_header_guard;
+  StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
+             internal_cc.get(), &internal_op_header_guard);
+
+  for (const auto& graph_op_def : ops.op()) {
+    // Skip deprecated ops.
+    // TODO(josh11b): If needed, can put them into a "deprecated" namespace
+    // instead of skipping.
+    if (graph_op_def.has_deprecation() &&
+        graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
+      continue;
+    }
+
+    // We use a hand-written wrapper for "Const", since the generated
+    // code depends on it.
+    if (graph_op_def.name() == "Const") continue;
+
+    // Incorporate overrides from override_map.
+    OpDef interface_op_def = graph_op_def;
+    const OpGenOverride* op_override =
+        override_map.ApplyOverride(&interface_op_def);
+    std::vector<string> aliases;
+    if (op_override) {
+      if (op_override->skip()) continue;
+      aliases.assign(op_override->alias().begin(), op_override->alias().end());
+      if (op_override->hide()) {
+        // Write hidden ops to _internal.h and _internal.cc.
+        WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
+                  internal_cc.get());
+        continue;
+      }
+    }
+
+    // This isn't a hidden op, write it to the main files.
+    WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
+  }
+
+  FinishFiles(false, h.get(), cc.get(), op_header_guard);
+  FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(),
+              internal_op_header_guard);
+}
+
 }  // namespace tensorflow
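Note: a quick sketch of the file naming this introduces, for a hypothetical
generator run with dot_h_fname = "ops/array_ops.h":

    MakeInternal("ops/array_ops.h");   // -> "ops/array_ops_internal.h"
    MakeInternal("ops/array_ops.cc");  // -> "ops/array_ops_internal.cc"

Ops with `hide: true` (e.g. MirrorPadGrad) are written to these _internal
files, inside namespace tensorflow::ops::internal; ops with `skip: true` are
not generated at all.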
@@ -17,12 +17,13 @@ limitations under the License.
 #define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_

 #include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/types.h"

 namespace tensorflow {

 /// Result is written to files dot_h and dot_cc.
-void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
-                const std::string& dot_cc_fname);
+void WriteCCOps(const OpList& ops, const string& dot_h_fname,
+                const string& dot_cc_fname, const string& overrides_fnames);

 }  // namespace tensorflow
@@ -24,10 +24,10 @@ namespace tensorflow {
 namespace {

 void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
-                   bool include_internal) {
+                   const std::string& overrides_fnames, bool include_internal) {
   OpList ops;
   OpRegistry::Global()->Export(include_internal, &ops);
-  WriteCCOps(ops, dot_h, dot_cc);
+  WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
 }

 }  // namespace
@@ -35,15 +35,18 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,

 int main(int argc, char* argv[]) {
   tensorflow::port::InitMain(argv[0], &argc, &argv);
-  if (argc != 4) {
+  if (argc != 5) {
     for (int i = 1; i < argc; ++i) {
       fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
     }
     fprintf(stderr,
-            "Usage: %s out.h out.cc include_internal\n"
+            "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
             "  include_internal: 1 means include internal ops\n",
             argv[0]);
     exit(1);
   }

-  bool include_internal = tensorflow::StringPiece("1") == argv[3];
-  tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal);
+  bool include_internal = tensorflow::StringPiece("1") == argv[4];
+  tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
   return 0;
 }
@@ -68,7 +68,7 @@ TEST(CCOpTest, Attrs) {
 TEST(CCOpTest, SplitConcat) {
   Scope root = Scope::NewRootScope();
   Split p(root, 0, {{1}, {2}}, 2);
-  auto c = Concat(root, 0, {p[0], p[1]});
+  auto c = Concat(root, {p[0], p[1]}, 0);
   TF_EXPECT_OK(root.status());
   Tensor out;
   test::GetTensor(root, c, &out);
@@ -15,6 +15,7 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"

 #include "tensorflow/cc/framework/grad_op_registry.h"
@@ -92,7 +93,7 @@ Status SplitGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
   grad_outputs->push_back(NoGradient());
-  grad_outputs->push_back(Concat(scope, op.input(0), grad_inputs));
+  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
   return scope.status();
 }
 REGISTER_GRADIENT_OP("Split", SplitGrad);
@@ -219,7 +220,7 @@ Status ReverseGrad(const Scope& scope, const Operation& op,
   grad_outputs->push_back(NoGradient());
   return scope.status();
 }
-REGISTER_GRADIENT_OP("Reverse", ReverseGrad);
+REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

 Status ScatterNdGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
@@ -319,8 +320,8 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op,
                      std::vector<Output>* grad_outputs) {
   string mode;
   TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
-  grad_outputs->push_back(
-      tensorflow::ops::MirrorPadGrad(scope, grad_inputs[0], op.input(1), mode));
+  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
+      scope, grad_inputs[0], op.input(1), mode));
   grad_outputs->push_back(NoGradient());
   return scope.status();
 }
@@ -17,12 +17,14 @@ limitations under the License.
 #include "tensorflow/cc/framework/gradient_checker.h"
 #include "tensorflow/cc/framework/testutil.h"
 #include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"

 namespace tensorflow {
 using namespace ops;  // NOLINT(build/namespaces)
+using ops::internal::MirrorPadGrad;

 namespace {

@@ -207,8 +209,7 @@ TEST_F(ArrayGradTest, ReverseSequenceGrad) {
 TEST_F(ArrayGradTest, ReverseGrad) {
   TensorShape shape({5, 2, 5});
   auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
-  auto reverse_dims = Const(scope_, {true, false, true});
-  auto y = Reverse(scope_, x, reverse_dims);
+  auto y = Reverse(scope_, x, {0, 2});
   RunTest(x, shape, y, shape);
 }
tensorflow/cc/ops/op_gen_overrides.pbtxt (new file, 42 lines)
@@ -0,0 +1,42 @@
+# array_ops
+op { name: "BroadcastArgs" rename_to: "BroadcastDynamicShape" }
+op { name: "ConcatOffset" skip: true }  # Maybe should just be hidden?
+op { name: "Concat" skip: true }
+op { name: "ConcatV2" rename_to: "Concat" }
+op { name: "MirrorPadGrad" hide: true }
+op { name: "Reverse" skip: true }
+op { name: "ReverseV2" rename_to: "Reverse" }
+
+# candidate_sampling_ops
+# control_flow_ops
+# ctc_ops
+# data_flow_ops
+op { name: "FakeQueue" skip: true }
+
+# functional_ops
+# image_ops
+# io_ops
+# linalg_ops
+# logging_ops
+# math_ops
+op { name: "All" alias: "ReduceAll" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Any" alias: "ReduceAny" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Max" alias: "ReduceMax" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Mean" alias: "ReduceMean" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Min" alias: "ReduceMin" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } }
+op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } }
+
+# nn_ops
+op { name: "TopKV2" rename_to: "TopK" alias: "TopKV2" }  # TODO(josh11b): delete "TopKV2" alias once users updated
+
+# parsing_ops
+# random_ops
+# script_ops
+# sdca_ops
+# state_ops
+# sparse_ops
+# string_ops
+# user_ops
+# training_ops
+# word2vec deprecated ops
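Note: a sketch of how rename_to and alias surface in the generated header
(surrounding boilerplate elided). rename_to changes the class name; each alias
becomes a typedef emitted by the WriteClassDecl change above:

    // op { name: "TopKV2" rename_to: "TopK" alias: "TopKV2" } produces roughly:
    class TopK { /* ... */ };
    typedef TopK TopKV2;  // keeps existing TopKV2 call sites compiling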
@@ -454,6 +454,7 @@ $(wildcard tensorflow/core/*/*/*testutil*) \
 $(wildcard tensorflow/core/*/*/*testlib*) \
 $(wildcard tensorflow/core/*/*/*main.cc) \
 $(wildcard tensorflow/core/debug/*.cc) \
+$(wildcard tensorflow/core/framework/op_gen_lib.cc) \
 $(wildcard tensorflow/core/graph/dot.*) \
 $(wildcard tensorflow/core/lib/gif/*) \
 $(wildcard tensorflow/core/lib/io/zlib*) \
@@ -328,7 +328,6 @@ tf_cuda_library(
         "framework/op.h",
         "framework/op_def_builder.h",
         "framework/op_def_util.h",
-        "framework/op_gen_lib.h",
         "framework/op_kernel.h",
         "framework/partial_tensor_shape.h",
         "framework/queue_interface.h",
@@ -384,6 +383,25 @@ tf_cuda_library(
     deps = [":framework_internal"],
 )

+tf_proto_library_cc(
+    name = "op_gen_overrides_proto",
+    srcs = ["framework/op_gen_overrides.proto"],
+    cc_api_version = 2,
+    protodeps = [":protos_all"],
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "op_gen_lib",
+    srcs = ["framework/op_gen_lib.cc"],
+    hdrs = ["framework/op_gen_lib.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":lib",
+        ":op_gen_overrides_proto_cc",
+    ],
+)
+
 cc_library(
     name = "framework_lite",
     srcs = tf_additional_minimal_lib_srcs(),
@@ -761,6 +779,7 @@ filegroup(
             "**/*testlib*",
             "**/*main.cc",
             "debug/**/*",
+            "framework/op_gen_*",
             "graph/dot.*",
             "lib/jpeg/**/*",
             "lib/png/**/*",
@@ -1250,6 +1269,7 @@ tf_cuda_library(
         "util/reporter.h",
         "util/reporter.cc",
         "framework/fake_input.*",
+        "framework/op_gen_lib.*",
         "util/memmapped_file_system.*",
         "util/memmapped_file_system_writer.*",
         "util/version_info.cc",
@@ -397,18 +397,22 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
       }
     }
     *result = target_context->MakeShape(dims);
-  } else if (src_op == "Concat") {
+  } else if (src_op == "Concat" || src_op == "ConcatV2") {
     *result = target_context->Scalar();
+    // For Concat, input 0 is concat dim; for V2 it is the last input.
+    const int concat_dim =
+        src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
     // Concat is concatenating its input shape vectors.
-    // input 0 is ignored as it is the concat dim and will always be 0.
-    for (int i = 1; i < src_context->num_inputs(); ++i) {
+    for (int i = 0; i < src_context->num_inputs(); ++i) {
+      // Concat dim is ignored (and will always be a scalar).
+      if (i == concat_dim) continue;
       ShapeHandle sub_result;
       TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
                                               i, &sub_result));
       if (!target_context->RankKnown(sub_result)) {
         // Failed to evaluate. Treat the output as completely unknown.
-        // TODO(cwhipkey): we could rely on all inputs being the same size, so
-        //   figure that size out and append the right number of unknown dims.
+        // TODO(cwhipkey): we could rely on all inputs being the same rank, so
+        //   figure that rank out and append the right number of unknown dims.
         *result = target_context->UnknownShape();
         return Status::OK();
       }
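Note: a worked example of the new indexing. For a ConcatV2 node with three
shape-vector inputs, the node's inputs are (s0, s1, s2, axis):

    // ConcatV2: num_inputs() == 4, so concat_dim == 4 - 1 == 3
    //   -> the loop visits i = 0, 1, 2 and skips i == 3 (the axis input)
    // Concat:   inputs are (concat_dim, s0, s1, s2), so concat_dim == 0
    //   -> the loop visits i = 1, 2, 3 and skips i == 0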
@@ -642,7 +642,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
       const_input,
   };  // clang-format on
   auto concat_dim = ops::Const(root, 0);
-  auto concat = ops::Concat(root, concat_dim, concat_inputs);
+  auto concat = ops::Concat(root, concat_inputs, concat_dim);
   TF_ASSERT_OK(root.status());

   Node* result;
@@ -684,7 +684,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
       ops::Shape(root, Output(unknown)),
   };  // clang-format on
   auto concat_dim = ops::Const(root, 0);
-  auto concat = ops::Concat(root, concat_dim, concat_inputs);
+  auto concat = ops::Concat(root, concat_inputs, concat_dim);
   TF_ASSERT_OK(root.status());

   Node* result;
@@ -726,7 +726,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
       const_input,
   };  // clang-format on
   auto concat_dim = ops::Const(root, 0);
-  auto concat = ops::Concat(root, concat_dim, concat_inputs);
+  auto concat = ops::Concat(root, concat_inputs, concat_dim);
   TF_ASSERT_OK(root.status());

   Node* result;
@@ -15,6 +15,9 @@ limitations under the License.

 #include "tensorflow/core/framework/op_gen_lib.h"

+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"

 namespace tensorflow {
@@ -67,4 +70,136 @@ bool ConsumeEquals(StringPiece* description) {
   return false;
 }

+Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
+  std::vector<string> v = str_util::Split(filenames, ",");
+  for (const string& f : v) {
+    TF_RETURN_IF_ERROR(LoadFile(env, f));
+  }
+  return Status::OK();
+}
+
+Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) {
+  if (filename.empty()) return Status::OK();
+  string contents;
+  TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
+  OpGenOverrides all;
+  protobuf::TextFormat::ParseFromString(contents, &all);
+  for (const auto& one : all.op()) {
+    map_[one.name()] = one;
+  }
+  return Status::OK();
+}
+
+static void StringReplace(const string& from, const string& to, string* s) {
+  std::vector<string> rest = str_util::Split(*s, from);
+  *s = str_util::Join(rest, to.c_str());
+}
+
+static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
+  const string from_quoted = strings::StrCat("`", from, "`");
+  const string to_quoted = strings::StrCat("`", to, "`");
+  for (int i = 0; i < op_def->input_arg_size(); ++i) {
+    if (!op_def->input_arg(i).description().empty()) {
+      StringReplace(from_quoted, to_quoted,
+                    op_def->mutable_input_arg(i)->mutable_description());
+    }
+  }
+  for (int i = 0; i < op_def->output_arg_size(); ++i) {
+    if (!op_def->output_arg(i).description().empty()) {
+      StringReplace(from_quoted, to_quoted,
+                    op_def->mutable_output_arg(i)->mutable_description());
+    }
+  }
+  for (int i = 0; i < op_def->attr_size(); ++i) {
+    if (!op_def->attr(i).description().empty()) {
+      StringReplace(from_quoted, to_quoted,
+                    op_def->mutable_attr(i)->mutable_description());
+    }
+  }
+  if (!op_def->summary().empty()) {
+    StringReplace(from_quoted, to_quoted, op_def->mutable_summary());
+  }
+  if (!op_def->description().empty()) {
+    StringReplace(from_quoted, to_quoted, op_def->mutable_description());
+  }
+}
+
+const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
+  // Look up
+  const auto iter = map_.find(op_def->name());
+  if (iter == map_.end()) return nullptr;
+  const OpGenOverride& proto = iter->second;
+
+  // Apply overrides from `proto`.
+  if (!proto.rename_to().empty()) {
+    op_def->set_name(proto.rename_to());
+    RenameInDocs(proto.name(), proto.rename_to(), op_def);
+  }
+  for (const auto& attr_default : proto.attr_default()) {
+    bool found = false;
+    for (int i = 0; i < op_def->attr_size(); ++i) {
+      if (op_def->attr(i).name() == attr_default.name()) {
+        *op_def->mutable_attr(i)->mutable_default_value() =
+            attr_default.value();
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      LOG(WARNING) << proto.name() << " can't find attr "
+                   << attr_default.name() << " to override default";
+    }
+  }
+  for (const auto& attr_rename : proto.attr_rename()) {
+    bool found = false;
+    for (int i = 0; i < op_def->attr_size(); ++i) {
+      if (op_def->attr(i).name() == attr_rename.from()) {
+        *op_def->mutable_attr(i)->mutable_name() = attr_rename.to();
+        found = true;
+        break;
+      }
+    }
+    if (found) {
+      RenameInDocs(attr_rename.from(), attr_rename.to(), op_def);
+    } else {
+      LOG(WARNING) << proto.name() << " can't find attr " << attr_rename.from()
+                   << " to rename";
+    }
+  }
+  for (const auto& input_rename : proto.input_rename()) {
+    bool found = false;
+    for (int i = 0; i < op_def->input_arg_size(); ++i) {
+      if (op_def->input_arg(i).name() == input_rename.from()) {
+        *op_def->mutable_input_arg(i)->mutable_name() = input_rename.to();
+        found = true;
+        break;
+      }
+    }
+    if (found) {
+      RenameInDocs(input_rename.from(), input_rename.to(), op_def);
+    } else {
+      LOG(WARNING) << proto.name() << " can't find input "
+                   << input_rename.from() << " to rename";
+    }
+  }
+  for (const auto& output_rename : proto.output_rename()) {
+    bool found = false;
+    for (int i = 0; i < op_def->output_arg_size(); ++i) {
+      if (op_def->output_arg(i).name() == output_rename.from()) {
+        *op_def->mutable_output_arg(i)->mutable_name() = output_rename.to();
+        found = true;
+        break;
+      }
+    }
+    if (found) {
+      RenameInDocs(output_rename.from(), output_rename.to(), op_def);
+    } else {
+      LOG(WARNING) << proto.name() << " can't find output "
+                   << output_rename.from() << " to rename";
+    }
+  }
+
+  return &proto;
+}
+
 }  // namespace tensorflow
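Note: StringReplace above is split-on-`from` plus join-with-`to`, so it
replaces every occurrence. A sketch of the effect RenameInDocs has on a doc
string (hypothetical input):

    string s = "Concatenates along `dim`. `dim` must be in range.";
    StringReplace("`dim`", "`axis`", &s);
    // s == "Concatenates along `axis`. `axis` must be in range."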
@@ -17,7 +17,12 @@ limitations under the License.
 #define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_

 #include <string>
+#include <unordered_map>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_overrides.pb.h"
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/env.h"

 namespace tensorflow {

@@ -34,6 +39,31 @@ string WordWrap(StringPiece prefix, StringPiece str, int width);
 // returns false.
 bool ConsumeEquals(StringPiece* description);

+// Takes a list of files with OpGenOverrides text protos, and allows you to
+// look up the specific override for any given op.
+class OpGenOverrideMap {
+ public:
+  // `filenames` is a comma-separated list of file names.  If an op
+  // is mentioned in more than one file, the last one takes priority.
+  Status LoadFileList(Env* env, const string& filenames);
+
+  // Load a single file.  If more than one file is loaded, later ones
+  // take priority for any ops in common.
+  Status LoadFile(Env* env, const string& filename);
+
+  // Look up the override for `*op_def` from the loaded files, and
+  // mutate `*op_def` to reflect the requested changes.  Does not apply
+  // 'skip', 'hide', or 'alias' overrides.  Caller has to deal with
+  // those since they can't be simulated by mutating `*op_def`.
+  // Returns nullptr if op is not in any loaded file.  Otherwise, the
+  // pointer must not be referenced beyond the lifetime of *this or
+  // the next file load.
+  const OpGenOverride* ApplyOverride(OpDef* op_def) const;
+
+ private:
+  std::unordered_map<string, OpGenOverride> map_;
+};
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
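Note: a minimal usage sketch of this class, mirroring how the new WriteCCOps
drives it above (error handling elided; file names hypothetical):

    OpGenOverrideMap override_map;
    TF_CHECK_OK(override_map.LoadFileList(
        Env::Default(), "ops/op_gen_overrides.pbtxt,more_overrides.pbtxt"));
    OpDef interface_op_def = graph_op_def;  // copy of the registered OpDef
    const OpGenOverride* op_override =
        override_map.ApplyOverride(&interface_op_def);
    // Caller must still handle skip()/hide()/alias() from *op_override.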
tensorflow/core/framework/op_gen_overrides.proto (new file, 67 lines)
@@ -0,0 +1,67 @@
+// Defines the text format for adding per-op overrides for client
+// language op code generators.
+
+syntax = "proto3";
+
+package tensorflow;
+import "tensorflow/core/framework/attr_value.proto";
+
+// Used to override the default API & behavior in the generated code
+// for client languages, from what you would get from the OpDef alone.
+// This is so we can evolve the API while remaining backwards
+// compatible when interpreting old graphs.  Overrides go in an
+// "op_gen_overrides.pbtxt" file with a text-format OpGenOverrides
+// message.  Right now these only apply to the C++ API.
+// TODO(josh11b): In the future there will be a common set of overrides
+// and per-client-language overrides.
+//
+// WARNING: Be *very* careful using these features -- these overrides
+// can change the semantics of existing code.  These changes may need
+// to wait until a major release of TensorFlow to avoid breaking our
+// compatibility promises.
+message OpGenOverride {
+  // Name of the op to apply overrides to.
+  string name = 1;
+
+  // Do not include this op in the generated API.
+  // If `skip` is true, all other overrides are ignored for this op.
+  bool skip = 2;
+
+  // Hide this op by putting it into an internal namespace (or whatever
+  // is appropriate in the target language).
+  bool hide = 3;
+
+  // Use a different name in the API than the op's name.  Note that
+  // the op's name in `backticks` will also be replaced in the docs.
+  string rename_to = 4;
+
+  // Create *additional* API endpoints with different names (contrast
+  // with rename_to, which affects the original name).
+  repeated string alias = 5;
+
+  // Map the name of an attr to a new default value to use.  This
+  // default will be used when creating new graphs, as opposed to the
+  // default in the OpDef, which will be used when interpreting old
+  // GraphDefs.  If this attr is also renamed (using attr_rename
+  // below), use the original name of the attr.
+  message AttrDefault {
+    string name = 1;
+    AttrValue value = 2;
+  }
+  repeated AttrDefault attr_default = 6;
+
+  // Change the name used to access attrs/inputs/outputs in the API
+  // from what is used in the GraphDef.  Note that these names in
+  // `backticks` will also be replaced in the docs.
+  message Rename {
+    string from = 1;
+    string to = 2;
+  }
+  repeated Rename attr_rename = 7;
+  repeated Rename input_rename = 8;
+  repeated Rename output_rename = 9;
+}
+
+message OpGenOverrides {
+  repeated OpGenOverride op = 1;
+}
@@ -1860,6 +1860,7 @@ values: The `k` largest elements along each last dimensional slice.
 indices: The indices of `values` within the last dimension of `input`.
 )doc");

+// This is the same as `TopK`, but takes `k` as an input rather than an attr.
 REGISTER_OP("TopKV2")
     .Input("input: T")
     .Input("k: int32")
@@ -1882,8 +1883,6 @@ row (resp. vector along the last dimension).  Thus,

 If two elements are equal, the lower-index element appears first.

-This is the same as `TopK`, but takes `k` as in input rather than an attr.
-
 input: 1-D or higher with last dimension at least `k`.
 k: 0-D.  Number of top elements to look for along the last dimension (along each
   row for matrices).
@@ -163,7 +163,7 @@ Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

   string output_name = "top_k";
-  TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
+  TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
   // This runs the GraphDef network definition that we've just constructed, and
   // returns the results in the output tensors.
   tensorflow::GraphDef graph;
@@ -184,7 +184,7 @@ Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels,
   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

   string output_name = "top_k";
-  TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
+  TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
   // This runs the GraphDef network definition that we've just constructed, and
   // returns the results in the output tensors.
   tensorflow::GraphDef graph;
@@ -315,6 +315,7 @@ cc_library(
     visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework",
+        "//tensorflow/core:op_gen_lib",
        "//tensorflow/core:protos_cc",
    ],
    alwayslink = 1,
@@ -143,6 +143,7 @@ def tf_gen_op_libs(op_lib_names, deps=None):
 def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
                          op_gen="//tensorflow/cc:cc_op_gen_main",
                          deps=None,
+                         override_file=None,
                          include_internal_ops=0):
   # Construct an op generator binary for these ops.
   tool = out_ops_file + "_gen_cc"
@@ -156,12 +157,21 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
       deps = [op_gen] + deps
   )

+  if override_file == None:
+    srcs = []
+    override_arg = ","
+  else:
+    srcs = [override_file]
+    override_arg = "$(location " + override_file + ")"
   native.genrule(
       name=name + "_genrule",
-      outs=[out_ops_file + ".h", out_ops_file + ".cc"],
+      outs=[out_ops_file + ".h", out_ops_file + ".cc",
+            out_ops_file + "_internal.h", out_ops_file + "_internal.cc"],
+      srcs=srcs,
       tools=[":" + tool],
       cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
-           "$(location :" + out_ops_file + ".cc) " + str(include_internal_ops)))
+           "$(location :" + out_ops_file + ".cc) " + override_arg + " " +
+           str(include_internal_ops)))

 # Given a list of "op_lib_names" (a list of files in the ops directory
 # without their .cc extensions), generate individual C++ .cc and .h
@@ -181,6 +191,15 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
 #            hdrs = [ "ops/array_ops.h",
 #                     "ops/math_ops.h" ],
 #            deps = [ ... ])
 #
+# Plus a private library for the "hidden" ops.
+# cc_library(name = "tf_ops_lib_internal",
+#            srcs = [ "ops/array_ops_internal.cc",
+#                     "ops/math_ops_internal.cc" ],
+#            hdrs = [ "ops/array_ops_internal.h",
+#                     "ops/math_ops_internal.h" ],
+#            deps = [ ... ])
+# TODO(josh11b): Cleaner approach for hidden ops.
 def tf_gen_op_wrappers_cc(name,
                           op_lib_names=[],
                           other_srcs=[],
@@ -192,16 +211,21 @@ def tf_gen_op_wrappers_cc(name,
                               "//tensorflow/cc:const_op",
                           ],
                           op_gen="//tensorflow/cc:cc_op_gen_main",
+                          override_file=None,
                           include_internal_ops=0,
                           visibility=None):
   subsrcs = other_srcs
   subhdrs = other_hdrs
+  internalsrcs = []
+  internalhdrs = []
   for n in op_lib_names:
     tf_gen_op_wrapper_cc(
-        n, "ops/" + n, pkg=pkg, op_gen=op_gen,
+        n, "ops/" + n, pkg=pkg, op_gen=op_gen, override_file=override_file,
         include_internal_ops=include_internal_ops)
     subsrcs += ["ops/" + n + ".cc"]
     subhdrs += ["ops/" + n + ".h"]
+    internalsrcs += ["ops/" + n + "_internal.cc"]
+    internalhdrs += ["ops/" + n + "_internal.h"]

   native.cc_library(name=name,
                     srcs=subsrcs,
@@ -217,6 +241,20 @@ def tf_gen_op_wrappers_cc(name,
                     copts=tf_copts(),
                     alwayslink=1,
                     visibility=visibility)
+  native.cc_library(name=name + "_internal",
+                    srcs=internalsrcs,
+                    hdrs=internalhdrs,
+                    deps=deps + if_not_android([
+                        "//tensorflow/core:core_cpu",
+                        "//tensorflow/core:framework",
+                        "//tensorflow/core:lib",
+                        "//tensorflow/core:protos_all_cc",
+                    ]) + if_android([
+                        "//tensorflow/core:android_tensorflow_lib",
+                    ]),
+                    copts=tf_copts(),
+                    alwayslink=1,
+                    visibility=["//visibility:private"])

 # Invoke this rule in .../tensorflow/python to build the wrapper library.
 def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[],
@@ -296,7 +296,7 @@ class QuantizeNodesTest : public ::testing::Test {
     Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));

     Output concat_op =
-        Concat(root.WithOpName("concat_op"), shape_op, {a_op, b_op});
+        Concat(root.WithOpName("concat_op"), {a_op, b_op}, shape_op);

     GraphDef float_graph_def;
     TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));