diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 38117d388fc..ad787ab0243 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -228,6 +228,7 @@ cc_library( srcs = ["gradients/array_grad.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":gradients", ], @@ -239,6 +240,7 @@ tf_cc_test( deps = [ ":array_grad", ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", @@ -334,6 +336,7 @@ tf_gen_op_wrappers_cc( "ops/const_op.h", "ops/standard_ops.h", ], + override_file = "ops/op_gen_overrides.pbtxt", pkg = "//tensorflow/core", ) @@ -387,6 +390,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 2efa82b3b42..5f85d8c5edf 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -15,6 +15,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include <vector> #include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -38,7 +39,7 @@ const int kRightMargin = 79; // Converts: // bazel-out/.../genfiles/(external/YYY/)?XX // to: XX. -string GetPath(const std::string& dot_h_fname) { +string GetPath(const string& dot_h_fname) { auto pos = dot_h_fname.find("/genfiles/"); string result = dot_h_fname; if (pos != string::npos) { @@ -60,7 +61,7 @@ string GetPath(const std::string& dot_h_fname) { // cc/ops/gen_foo_ops.h // to: // CC_OPS_GEN_FOO_OPS_H_ -string ToGuard(const std::string& path) { +string ToGuard(const string& path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { @@ -360,7 +361,12 @@ bool HasOptionalAttrs( } struct OpInfo { - explicit OpInfo(const OpDef& op_def); + // graph_op_def: The OpDef used by the runtime, has the names that + // must be used when calling NodeBuilder. + // interface_op_def: The OpDef used in the interface in the generated + // code, with possibly overridden names and defaults. + explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def, + const std::vector<string>& aliases); string GetOpAttrStruct() const; string GetConstructorDecl(StringPiece op_name_prefix, bool include_attr) const; @@ -378,11 +384,15 @@ struct OpInfo { bool has_optional_attrs; string comment; + const OpDef& graph_op_def; const OpDef& op_def; + const std::vector<string>& aliases; std::unordered_map<string, string> inferred_input_attrs; }; -OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) { +OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, + const std::vector<string>& a) + : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) { op_name = op_def.name(); InferOpAttributes(op_def, &inferred_input_attrs); has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs); @@ -443,7 +453,6 @@ OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) { strings::StrAppend(&comment, " ", attr.description(), "\n"); } } - comment = MakeComment(comment, ""); for (int i = 0; i < op_def.output_arg_size(); ++i) { const auto& arg = op_def.output_arg(i); @@ -453,13 +462,55 @@ OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) { output_names.push_back(AvoidCPPKeywords(arg.name())); is_list_output.push_back(is_list); } + + strings::StrAppend(&comment, "\nReturns:\n"); + if (op_def.output_arg_size() == 0) { // No outputs. + strings::StrAppend(&comment, "* the created `Operation`\n"); + } else if (op_def.output_arg_size() == 1) { // One output + if (is_list_output[0]) { + strings::StrAppend(&comment, "* `OutputList`: "); + } else { + strings::StrAppend(&comment, "* `Output`: "); + } + if (op_def.output_arg(0).description().empty()) { + strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(), + " tensor.\n"); + } else { + // TODO(josh11b): Word wrap this. + strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n"); + } + } else { // Multiple outputs. + for (int i = 0; i < op_def.output_arg_size(); ++i) { + if (is_list_output[i]) { + strings::StrAppend(&comment, "* `OutputList`"); + } else { + strings::StrAppend(&comment, "* `Output`"); + } + strings::StrAppend(&comment, " ", output_names[i]); + if (op_def.output_arg(i).description().empty()) { + strings::StrAppend(&comment, "\n"); + } else { + // TODO(josh11b): Word wrap this. + strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(), + "\n"); + } + } + } + + if (!aliases.empty()) { + strings::StrAppend(&comment, "\nAliases:\n"); + for (const auto& alias : aliases) { + strings::StrAppend(&comment, "* ", alias, "\n"); + } + } + comment = MakeComment(comment, ""); } string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; - string attrs_comment = strings::StrCat("Optional attribute setters for ", - op_def.name(), " :\n\n"); + string attrs_comment = + strings::StrCat("Optional attribute setters for ", op_name, " :\n\n"); for (int i = 0; i < op_def.attr_size(); ++i) { const auto& attr(op_def.attr(i)); @@ -603,7 +654,13 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { ";\n"); } - strings::StrAppend(&class_decl, "};\n\n"); + strings::StrAppend(&class_decl, "};\n"); + if (!aliases.empty()) { + for (const auto& alias : aliases) { + strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n"); + } + } + strings::StrAppend(&class_decl, "\n"); TF_CHECK_OK(h->Append(class_decl)); } @@ -642,7 +699,7 @@ void OpInfo::GetOutput(string* out) const { for (int i = 0; i < op_def.output_arg_size(); ++i) { const string arg_range = strings::StrCat( - "_outputs_range[\"", op_def.output_arg(i).name(), "\"]"); + "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ", arg_range, ".second; ++i)\n"); @@ -673,17 +730,18 @@ string OpInfo::GetConstructorBody() const { } strings::StrAppend(&body, " ::tensorflow::Node* ret;\n"); - strings::StrAppend(&body, " const auto unique_name = ", scope_str, - ".GetUniqueNameForOp(\"", op_def.name(), "\");\n"); + strings::StrAppend(&body, " const auto unique_name = ", scope_str, + ".GetUniqueNameForOp(\"", op_name, "\");\n"); strings::StrAppend( &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"", - op_def.name(), "\")\n"); + graph_op_def.name(), "\")\n"); const string spaces = " "; for (int i = 0; i < op_def.input_arg_size(); ++i) { const auto& arg(op_def.input_arg(i)); strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n"); } for (int i = 0; i < op_def.attr_size(); ++i) { + const auto& graph_attr(graph_op_def.attr(i)); const auto& attr(op_def.attr(i)); if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) { continue; @@ -691,7 +749,7 @@ string OpInfo::GetConstructorBody() const { const string attr_name = attr.has_default_value() ? strings::StrCat("attrs.", attr.name(), "_") : AvoidCPPKeywords(attr.name()); - strings::StrAppend(&body, spaces, ".Attr(\"", attr.name(), "\", ", + strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ", attr_name, ")\n"); } strings::StrAppend(&body, " ;\n"); @@ -736,23 +794,17 @@ void OpInfo::WriteClassDef(WritableFile* cc) const { TF_CHECK_OK(cc->Append(class_def)); } -void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) { - OpInfo op_info(op_def); +void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def, + const std::vector<string>& aliases, WritableFile* h, + WritableFile* cc) { + OpInfo op_info(graph_op_def, interface_op_def, aliases); op_info.WriteClassDecl(h); op_info.WriteClassDef(cc); } -} // namespace - -void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, - const std::string& dot_cc_fname) { - Env* env = Env::Default(); - std::unique_ptr<WritableFile> h = nullptr; - std::unique_ptr<WritableFile> cc = nullptr; - TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); - TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); - +void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h, + WritableFile* cc, string* op_header_guard) { const string header = R"header(// This file is MACHINE GENERATED! Do not edit. @@ -765,18 +817,22 @@ void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, )header"; // TODO(keveman): Make namespaces configurable. - const string namespace_begin = R"namespace( + const string namespace_begin = internal ? R"namespace( +namespace tensorflow { +namespace ops { +namespace internal { +// NOTE: This namespace has internal TensorFlow details that +// are not part of TensorFlow's public API. + +)namespace" + : R"namespace( namespace tensorflow { namespace ops { )namespace"; - const string footer = R"footer(} // namespace ops -} // namespace tensorflow -)footer"; - const string op_header = GetPath(dot_h_fname); - const string op_header_guard = ToGuard(op_header); + *op_header_guard = ToGuard(op_header); const string cc_header = strings::StrCat( R"include(// This file is MACHINE GENERATED! Do not edit. @@ -788,27 +844,25 @@ namespace ops { TF_CHECK_OK(h->Append( strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" "#ifndef ", - op_header_guard, + *op_header_guard, "\n" "#define ", - op_header_guard, "\n\n"))); + *op_header_guard, "\n\n"))); TF_CHECK_OK(h->Append(header)); TF_CHECK_OK(h->Append(namespace_begin)); TF_CHECK_OK(cc->Append(cc_header)); +} - for (const auto& op_def : ops.op()) { - // Skip deprecated ops. - // TODO(josh11b): If needed, can put them into a "deprecated" namespace - // instead of skipping. - if (op_def.has_deprecation() && - op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) { - continue; - } - // We use a hand-written wrapper for "Const", since the generated - // code depends on it. - if (op_def.name() == "Const") continue; - WriteCCOp(op_def, h.get(), cc.get()); - } +void FinishFiles(bool internal, WritableFile* h, WritableFile* cc, + const string& op_header_guard) { + const string footer = internal ? R"footer(} // namespace internal +} // namespace ops +} // namespace tensorflow +)footer" + : + R"footer(} // namespace ops +} // namespace tensorflow +)footer"; TF_CHECK_OK(h->Append(footer)); TF_CHECK_OK( @@ -819,4 +873,82 @@ namespace ops { TF_CHECK_OK(h->Close()); } +string MakeInternal(const string& fname) { + auto dot_pos = fname.rfind('.'); + if (dot_pos == string::npos) { + return strings::StrCat(fname, "_internal"); + } else { + return strings::StrCat(fname.substr(0, dot_pos), "_internal", + fname.substr(dot_pos)); + } +} + +} // namespace + +void WriteCCOps(const OpList& ops, const string& dot_h_fname, + const string& dot_cc_fname, const string& overrides_fnames) { + Env* env = Env::Default(); + + // Load the override map. + OpGenOverrideMap override_map; + if (!overrides_fnames.empty()) { + override_map.LoadFileList(env, overrides_fnames); + } + + // Write the initial boilerplate to the .h and .cc files. + std::unique_ptr<WritableFile> h = nullptr; + std::unique_ptr<WritableFile> cc = nullptr; + TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); + TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); + string op_header_guard; + StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard); + + // Create the internal versions of these files for the hidden ops. + std::unique_ptr<WritableFile> internal_h = nullptr; + std::unique_ptr<WritableFile> internal_cc = nullptr; + const string internal_dot_h_fname = MakeInternal(dot_h_fname); + TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h)); + TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc)); + string internal_op_header_guard; + StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(), + internal_cc.get(), &internal_op_header_guard); + + for (const auto& graph_op_def : ops.op()) { + // Skip deprecated ops. + // TODO(josh11b): If needed, can put them into a "deprecated" namespace + // instead of skipping. + if (graph_op_def.has_deprecation() && + graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) { + continue; + } + + // We use a hand-written wrapper for "Const", since the generated + // code depends on it. + if (graph_op_def.name() == "Const") continue; + + // Incorporate overrides from override_map. + OpDef interface_op_def = graph_op_def; + const OpGenOverride* op_override = + override_map.ApplyOverride(&interface_op_def); + std::vector<string> aliases; + if (op_override) { + if (op_override->skip()) continue; + aliases.assign(op_override->alias().begin(), op_override->alias().end()); + if (op_override->hide()) { + // Write hidden ops to _internal.h and _internal.cc. + WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(), + internal_cc.get()); + continue; + } + } + + // This isn't a hidden op, write it to the main files. + WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get()); + } + + FinishFiles(false, h.get(), cc.get(), op_header_guard); + FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(), + internal_op_header_guard); +} + } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index 3d35d0ef32b..fa5e004f031 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -17,12 +17,13 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { /// Result is written to files dot_h and dot_cc. -void WriteCCOps(const OpList& ops, const std::string& dot_h_fname, - const std::string& dot_cc_fname); +void WriteCCOps(const OpList& ops, const string& dot_h_fname, + const string& dot_cc_fname, const string& overrides_fnames); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 6c5a4beb217..3b80cf993eb 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -24,10 +24,10 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - bool include_internal) { + const std::string& overrides_fnames, bool include_internal) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); - WriteCCOps(ops, dot_h, dot_cc); + WriteCCOps(ops, dot_h, dot_cc, overrides_fnames); } } // namespace @@ -35,15 +35,18 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 4) { + if (argc != 5) { + for (int i = 1; i < argc; ++i) { + fprintf(stderr, "Arg %d = %s\n", i, argv[i]); + } fprintf(stderr, - "Usage: %s out.h out.cc include_internal\n" + "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } - bool include_internal = tensorflow::StringPiece("1") == argv[3]; - tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal); + bool include_internal = tensorflow::StringPiece("1") == argv[4]; + tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal); return 0; } diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 1a055bf6257..6dc0d84c16d 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -68,7 +68,7 @@ TEST(CCOpTest, Attrs) { TEST(CCOpTest, SplitConcat) { Scope root = Scope::NewRootScope(); Split p(root, 0, {{1}, {2}}, 2); - auto c = Concat(root, 0, {p[0], p[1]}); + auto c = Concat(root, {p[0], p[1]}, 0); TF_EXPECT_OK(root.status()); Tensor out; test::GetTensor(root, c, &out); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 178bb001d07..3b0e563986b 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,6 +15,7 @@ limitations under the License. #include <vector> +#include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/framework/grad_op_registry.h" @@ -92,7 +93,7 @@ Status SplitGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(Concat(scope, op.input(0), grad_inputs)); + grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0))); return scope.status(); } REGISTER_GRADIENT_OP("Split", SplitGrad); @@ -219,7 +220,7 @@ Status ReverseGrad(const Scope& scope, const Operation& op, grad_outputs->push_back(NoGradient()); return scope.status(); } -REGISTER_GRADIENT_OP("Reverse", ReverseGrad); +REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad); Status ScatterNdGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, @@ -319,8 +320,8 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, std::vector<Output>* grad_outputs) { string mode; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); - grad_outputs->push_back( - tensorflow::ops::MirrorPadGrad(scope, grad_inputs[0], op.input(1), mode)); + grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( + scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); } diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 6f5d9885a5e..0cdb3b7a1fb 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -17,12 +17,14 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { using namespace ops; // NOLINT(build/namespaces) +using ops::internal::MirrorPadGrad; namespace { @@ -207,8 +209,7 @@ TEST_F(ArrayGradTest, ReverseSequenceGrad) { TEST_F(ArrayGradTest, ReverseGrad) { TensorShape shape({5, 2, 5}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto reverse_dims = Const(scope_, {true, false, true}); - auto y = Reverse(scope_, x, reverse_dims); + auto y = Reverse(scope_, x, {0, 2}); RunTest(x, shape, y, shape); } diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt new file mode 100644 index 00000000000..79d3a04012a --- /dev/null +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -0,0 +1,42 @@ +# array_ops +op { name: "BroadcastArgs" rename_to: "BroadcastDynamicShape" } +op { name: "ConcatOffset" skip: true } # Maybe should just be hidden? +op { name: "Concat" skip: true } +op { name: "ConcatV2" rename_to: "Concat" } +op { name: "MirrorPadGrad" hide: true } +op { name: "Reverse" skip: true } +op { name: "ReverseV2" rename_to: "Reverse" } + +# candidate_sampling_ops +# control_flow_ops +# ctc_ops +# data_flow_ops +op { name: "FakeQueue" skip: true } + +# functional_ops +# image_ops +# io_ops +# linalg_ops +# logging_ops +# math_ops +op { name: "All" alias: "ReduceAll" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Any" alias: "ReduceAny" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Max" alias: "ReduceMax" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Mean" alias: "ReduceMean" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Min" alias: "ReduceMin" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } } +op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } } + +# nn_ops +op { name: "TopKV2" rename_to: "TopK" alias: "TopKV2" } # TODO(josh11b): delete "TopKV2" alias once users updated + +# parsing_ops +# random_ops +# script_ops +# sdca_ops +# state_ops +# sparse_ops +# string_ops +# user_ops +# training_ops +# word2vec deprecated ops diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index bb1d8d387b4..bee41889f51 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -454,6 +454,7 @@ $(wildcard tensorflow/core/*/*/*testutil*) \ $(wildcard tensorflow/core/*/*/*testlib*) \ $(wildcard tensorflow/core/*/*/*main.cc) \ $(wildcard tensorflow/core/debug/*.cc) \ +$(wildcard tensorflow/core/framework/op_gen_lib.cc) \ $(wildcard tensorflow/core/graph/dot.*) \ $(wildcard tensorflow/core/lib/gif/*) \ $(wildcard tensorflow/core/lib/io/zlib*) \ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7f251b557cf..6d8971e90ba 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -328,7 +328,6 @@ tf_cuda_library( "framework/op.h", "framework/op_def_builder.h", "framework/op_def_util.h", - "framework/op_gen_lib.h", "framework/op_kernel.h", "framework/partial_tensor_shape.h", "framework/queue_interface.h", @@ -384,6 +383,25 @@ tf_cuda_library( deps = [":framework_internal"], ) +tf_proto_library_cc( + name = "op_gen_overrides_proto", + srcs = ["framework/op_gen_overrides.proto"], + cc_api_version = 2, + protodeps = [":protos_all"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "op_gen_lib", + srcs = ["framework/op_gen_lib.cc"], + hdrs = ["framework/op_gen_lib.h"], + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":op_gen_overrides_proto_cc", + ], +) + cc_library( name = "framework_lite", srcs = tf_additional_minimal_lib_srcs(), @@ -761,6 +779,7 @@ filegroup( "**/*testlib*", "**/*main.cc", "debug/**/*", + "framework/op_gen_*", "graph/dot.*", "lib/jpeg/**/*", "lib/png/**/*", @@ -1250,6 +1269,7 @@ tf_cuda_library( "util/reporter.h", "util/reporter.cc", "framework/fake_input.*", + "framework/op_gen_lib.*", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", "util/version_info.cc", diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index e1500ed1ad9..dc1272c5d68 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -397,18 +397,22 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, } } *result = target_context->MakeShape(dims); - } else if (src_op == "Concat") { + } else if (src_op == "Concat" || src_op == "ConcatV2") { *result = target_context->Scalar(); + // For Concat, input 0 is concat dim; for V2 it is the last input. + const int concat_dim = + src_op == "Concat" ? 0 : src_context->num_inputs() - 1; // Concat is concatenating its input shape vectors. - // input 0 is ignored as it is the concat dim and will always be 0. - for (int i = 1; i < src_context->num_inputs(); ++i) { + for (int i = 0; i < src_context->num_inputs(); ++i) { + // Concat dim is ignored (and will always be a scalar). + if (i == concat_dim) continue; ShapeHandle sub_result; TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(), i, &sub_result)); if (!target_context->RankKnown(sub_result)) { // Failed to evaluate. Treat the output as completely unknown. - // TODO(cwhipkey): we could rely on all inputs being the same size, so - // figure that size out and append the right number of unknown dims. + // TODO(cwhipkey): we could rely on all inputs being the same rank, so + // figure that rank out and append the right number of unknown dims. *result = target_context->UnknownShape(); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index f7d5a9cfc9d..1515808a63d 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -642,7 +642,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) { const_input, }; // clang-format on auto concat_dim = ops::Const(root, 0); - auto concat = ops::Concat(root, concat_dim, concat_inputs); + auto concat = ops::Concat(root, concat_inputs, concat_dim); TF_ASSERT_OK(root.status()); Node* result; @@ -684,7 +684,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { ops::Shape(root, Output(unknown)), }; // clang-format on auto concat_dim = ops::Const(root, 0); - auto concat = ops::Concat(root, concat_dim, concat_inputs); + auto concat = ops::Concat(root, concat_inputs, concat_dim); TF_ASSERT_OK(root.status()); Node* result; @@ -726,7 +726,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { const_input, }; // clang-format on auto concat_dim = ops::Const(root, 0); - auto concat = ops::Concat(root, concat_dim, concat_inputs); + auto concat = ops::Concat(root, concat_inputs, concat_dim); TF_ASSERT_OK(root.status()); Node* result; diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 9b40e3e1367..95ad0d4bec8 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" +#include <vector> +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -67,4 +70,136 @@ bool ConsumeEquals(StringPiece* description) { return false; } +Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) { + std::vector<string> v = str_util::Split(filenames, ","); + for (const string& f : v) { + TF_RETURN_IF_ERROR(LoadFile(env, f)); + } + return Status::OK(); +} + +Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) { + if (filename.empty()) return Status::OK(); + string contents; + TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); + OpGenOverrides all; + protobuf::TextFormat::ParseFromString(contents, &all); + for (const auto& one : all.op()) { + map_[one.name()] = one; + } + return Status::OK(); +} + +static void StringReplace(const string& from, const string& to, string* s) { + std::vector<string> rest = str_util::Split(*s, from); + *s = str_util::Join(rest, to.c_str()); +} + +static void RenameInDocs(const string& from, const string& to, OpDef* op_def) { + const string from_quoted = strings::StrCat("`", from, "`"); + const string to_quoted = strings::StrCat("`", to, "`"); + for (int i = 0; i < op_def->input_arg_size(); ++i) { + if (!op_def->input_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_input_arg(i)->mutable_description()); + } + } + for (int i = 0; i < op_def->output_arg_size(); ++i) { + if (!op_def->output_arg(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_output_arg(i)->mutable_description()); + } + } + for (int i = 0; i < op_def->attr_size(); ++i) { + if (!op_def->attr(i).description().empty()) { + StringReplace(from_quoted, to_quoted, + op_def->mutable_attr(i)->mutable_description()); + } + } + if (!op_def->summary().empty()) { + StringReplace(from_quoted, to_quoted, op_def->mutable_summary()); + } + if (!op_def->description().empty()) { + StringReplace(from_quoted, to_quoted, op_def->mutable_summary()); + } +} + +const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const { + // Look up + const auto iter = map_.find(op_def->name()); + if (iter == map_.end()) return nullptr; + const OpGenOverride& proto = iter->second; + + // Apply overrides from `proto`. + if (!proto.rename_to().empty()) { + op_def->set_name(proto.rename_to()); + RenameInDocs(proto.name(), proto.rename_to(), op_def); + } + for (const auto& attr_default : proto.attr_default()) { + bool found = false; + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == attr_default.name()) { + *op_def->mutable_attr(i)->mutable_default_value() = + attr_default.value(); + found = true; + break; + } + } + if (!found) { + LOG(WARNING) << proto.name() << " can't find attr " << attr_default.name() + << " to override default"; + } + } + for (const auto& attr_rename : proto.attr_rename()) { + bool found = false; + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == attr_rename.from()) { + *op_def->mutable_attr(i)->mutable_name() = attr_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(attr_rename.from(), attr_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find attr " << attr_rename.from() + << " to rename"; + } + } + for (const auto& input_rename : proto.input_rename()) { + bool found = false; + for (int i = 0; i < op_def->input_arg_size(); ++i) { + if (op_def->input_arg(i).name() == input_rename.from()) { + *op_def->mutable_input_arg(i)->mutable_name() = input_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(input_rename.from(), input_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find input " + << input_rename.from() << " to rename"; + } + } + for (const auto& output_rename : proto.output_rename()) { + bool found = false; + for (int i = 0; i < op_def->output_arg_size(); ++i) { + if (op_def->output_arg(i).name() == output_rename.from()) { + *op_def->mutable_output_arg(i)->mutable_name() = output_rename.to(); + found = true; + break; + } + } + if (found) { + RenameInDocs(output_rename.from(), output_rename.to(), op_def); + } else { + LOG(WARNING) << proto.name() << " can't find output " + << output_rename.from() << " to rename"; + } + } + + return &proto; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index 83ead50a6ad..e92dc8d9241 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -17,7 +17,12 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ #include <string> +#include <unordered_map> +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/env.h" namespace tensorflow { @@ -34,6 +39,31 @@ string WordWrap(StringPiece prefix, StringPiece str, int width); // returns false. bool ConsumeEquals(StringPiece* description); +// Takes a list of files with OpGenOverrides text protos, and allows you to +// look up the specific override for any given op. +class OpGenOverrideMap { + public: + // `filenames` is a comma-separated list of file names. If an op + // is mentioned in more than one file, the last one takes priority. + Status LoadFileList(Env* env, const string& filenames); + + // Load a single file. If more than one file is loaded, later ones + // take priority for any ops in common. + Status LoadFile(Env* env, const string& filename); + + // Look up the override for `*op_def` from the loaded files, and + // mutate `*op_def` to reflect the requested changes. Does not apply + // 'skip', 'hide', or 'alias' overrides. Caller has to deal with + // those since they can't be simulated by mutating `*op_def`. + // Returns nullptr if op is not in any loaded file. Otherwise, the + // pointer must not be referenced beyond the lifetime of *this or + // the next file load. + const OpGenOverride* ApplyOverride(OpDef* op_def) const; + + private: + std::unordered_map<string, OpGenOverride> map_; +}; + } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ diff --git a/tensorflow/core/framework/op_gen_overrides.proto b/tensorflow/core/framework/op_gen_overrides.proto new file mode 100644 index 00000000000..8e66d39a7c7 --- /dev/null +++ b/tensorflow/core/framework/op_gen_overrides.proto @@ -0,0 +1,67 @@ +// Defines the text format for adding per-op overrides for client +// language op code generators. + +syntax = "proto3"; + +package tensorflow; +import "tensorflow/core/framework/attr_value.proto"; + +// Used to override the default API & behavior in the generated code +// for client languages, from what you would get from the OpDef alone. +// This is so we can evolve the API while remaining backwards +// compatible when interpretting old graphs. Overrides go in an +// "op_gen_overrides.pbtxt" file with a text-format OpGenOverrides +// message. Right now these only apply to the C++ API. +// TODO(josh11b): In the future there will be a common set of overrides +// and per-client-language overrides. +// +// WARNING: Be *very* careful using these features -- these overrides +// can change the semantics of existing code. These changes may need +// to wait until a major release of TensorFlow to avoid breaking our +// compatibility promises. +message OpGenOverride { + // Name of the op to apply overrides to. + string name = 1; + + // Do not include this op in the generated API. + // If `skip` is true, all other overrides are ignored for this op. + bool skip = 2; + + // Hide this op by putting it into an internal namespace (or whatever + // is appropriate in the target language). + bool hide = 3; + + // Use a different name in the API than the op's name. Note that + // the op's name in `backticks` will also be replaced in the docs. + string rename_to = 4; + + // Create *additional* API endpoints with different names (contrast + // with rename_to, which affects the original name). + repeated string alias = 5; + + // Map the name of an attr to a new default value to use. This + // default will be used when creating new graphs, as opposed to the + // default in the OpDef, which will be used when interpreting old + // GraphDefs. If this attr is also renamed (using attr_rename + // below), use the original name of the attr. + message AttrDefault { + string name = 1; + AttrValue value = 2; + } + repeated AttrDefault attr_default = 6; + + // Change the name used to access attrs/inputs/outputs in the API + // from what is used in the GraphDef. Note that these names in + // `backticks` will also be replaced in the docs. + message Rename { + string from = 1; + string to = 2; + } + repeated Rename attr_rename = 7; + repeated Rename input_rename = 8; + repeated Rename output_rename = 9; +} + +message OpGenOverrides { + repeated OpGenOverride op = 1; +} diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 02440bd6262..2d757b7a17d 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1860,6 +1860,7 @@ values: The `k` largest elements along each last dimensional slice. indices: The indices of `values` within the last dimension of `input`. )doc"); +// This is the same as `TopK`, but takes `k` as in input rather than an attr. REGISTER_OP("TopKV2") .Input("input: T") .Input("k: int32") @@ -1882,8 +1883,6 @@ row (resp. vector along the last dimension). Thus, If two elements are equal, the lower-index element appears first. -This is the same as `TopK`, but takes `k` as in input rather than an attr. - input: 1-D or higher with last dimension at least `k`. k: 0-D. Number of top elements to look for along the last dimension (along each row for matrices). diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 08e6e4544a2..fa024010281 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -163,7 +163,7 @@ Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels, using namespace ::tensorflow::ops; // NOLINT(build/namespaces) string output_name = "top_k"; - TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels); + TopK(root.WithOpName(output_name), outputs[0], how_many_labels); // This runs the GraphDef network definition that we've just constructed, and // returns the results in the output tensors. tensorflow::GraphDef graph; diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc index 42972078e53..0d6875671b1 100644 --- a/tensorflow/examples/multibox_detector/main.cc +++ b/tensorflow/examples/multibox_detector/main.cc @@ -184,7 +184,7 @@ Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels, using namespace ::tensorflow::ops; // NOLINT(build/namespaces) string output_name = "top_k"; - TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels); + TopK(root.WithOpName(output_name), outputs[0], how_many_labels); // This runs the GraphDef network definition that we've just constructed, and // returns the results in the output tensors. tensorflow::GraphDef graph; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2de7cc83829..5649e4aebae 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -315,6 +315,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_cc", ], alwayslink = 1, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 0e5b39af10d..2977cc48eb3 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -143,6 +143,7 @@ def tf_gen_op_libs(op_lib_names, deps=None): def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", op_gen="//tensorflow/cc:cc_op_gen_main", deps=None, + override_file=None, include_internal_ops=0): # Construct an op generator binary for these ops. tool = out_ops_file + "_gen_cc" @@ -156,12 +157,21 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", deps = [op_gen] + deps ) + if override_file == None: + srcs = [] + override_arg = "," + else: + srcs = [override_file] + override_arg = "$(location " + override_file + ")" native.genrule( name=name + "_genrule", - outs=[out_ops_file + ".h", out_ops_file + ".cc"], + outs=[out_ops_file + ".h", out_ops_file + ".cc", + out_ops_file + "_internal.h", out_ops_file + "_internal.cc"], + srcs=srcs, tools=[":" + tool], cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " + - "$(location :" + out_ops_file + ".cc) " + str(include_internal_ops))) + "$(location :" + out_ops_file + ".cc) " + override_arg + " " + + str(include_internal_ops))) # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate individual C++ .cc and .h @@ -181,6 +191,15 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="", # hdrs = [ "ops/array_ops.h", # "ops/math_ops.h" ], # deps = [ ... ]) +# +# Plus a private library for the "hidden" ops. +# cc_library(name = "tf_ops_lib_internal", +# srcs = [ "ops/array_ops_internal.cc", +# "ops/math_ops_internal.cc" ], +# hdrs = [ "ops/array_ops_internal.h", +# "ops/math_ops_internal.h" ], +# deps = [ ... ]) +# TODO(josh11b): Cleaner approach for hidden ops. def tf_gen_op_wrappers_cc(name, op_lib_names=[], other_srcs=[], @@ -192,16 +211,21 @@ def tf_gen_op_wrappers_cc(name, "//tensorflow/cc:const_op", ], op_gen="//tensorflow/cc:cc_op_gen_main", + override_file=None, include_internal_ops=0, visibility=None): subsrcs = other_srcs subhdrs = other_hdrs + internalsrcs = [] + internalhdrs = [] for n in op_lib_names: tf_gen_op_wrapper_cc( - n, "ops/" + n, pkg=pkg, op_gen=op_gen, + n, "ops/" + n, pkg=pkg, op_gen=op_gen, override_file=override_file, include_internal_ops=include_internal_ops) subsrcs += ["ops/" + n + ".cc"] subhdrs += ["ops/" + n + ".h"] + internalsrcs += ["ops/" + n + "_internal.cc"] + internalhdrs += ["ops/" + n + "_internal.h"] native.cc_library(name=name, srcs=subsrcs, @@ -217,6 +241,20 @@ def tf_gen_op_wrappers_cc(name, copts=tf_copts(), alwayslink=1, visibility=visibility) + native.cc_library(name=name + "_internal", + srcs=internalsrcs, + hdrs=internalhdrs, + deps=deps + if_not_android([ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ]) + if_android([ + "//tensorflow/core:android_tensorflow_lib", + ]), + copts=tf_copts(), + alwayslink=1, + visibility=["//visibility:private"]) # Invoke this rule in .../tensorflow/python to build the wrapper library. def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[], diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index e2bdb842948..c4de14d7a8d 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -296,7 +296,7 @@ class QuantizeNodesTest : public ::testing::Test { Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); Output concat_op = - Concat(root.WithOpName("concat_op"), shape_op, {a_op, b_op}); + Concat(root.WithOpName("concat_op"), {a_op, b_op}, shape_op); GraphDef float_graph_def; TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));