Add a mechanism for hiding, skipping, and modifying the generated op

functions for C++.  A souped-up version of the hidden_ops mechanism in
Python, the intent is to use this for most or all of the client
languages, with a common list of changes to make in a common file and
per-language overrides.

Also:
* include the documentation for outputs in the generated comments
* several updates to C++ API to match Python
* fix C++ shape function for ConcatV2 now that we use it by default
* split op_gen_lib out of core:framework, since it is only used by
  the op generators, and I don't want to add another proto to
  mobile builds
Change: 146267344
This commit is contained in:
A. Unique TensorFlower 2017-02-01 11:23:07 -08:00 committed by TensorFlower Gardener
parent 287e845c52
commit 8fe32029f7
21 changed files with 557 additions and 78 deletions

View File

@ -228,6 +228,7 @@ cc_library(
srcs = ["gradients/array_grad.cc"],
deps = [
":cc_ops",
":cc_ops_internal",
":grad_op_registry",
":gradients",
],
@ -239,6 +240,7 @@ tf_cc_test(
deps = [
":array_grad",
":cc_ops",
":cc_ops_internal",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
@ -334,6 +336,7 @@ tf_gen_op_wrappers_cc(
"ops/const_op.h",
"ops/standard_ops.h",
],
override_file = "ops/op_gen_overrides.pbtxt",
pkg = "//tensorflow/core",
)
@ -387,6 +390,7 @@ cc_library_with_android_deps(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/attr_value_util.h"
@ -38,7 +39,7 @@ const int kRightMargin = 79;
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
string GetPath(const std::string& dot_h_fname) {
string GetPath(const string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
string result = dot_h_fname;
if (pos != string::npos) {
@ -60,7 +61,7 @@ string GetPath(const std::string& dot_h_fname) {
// cc/ops/gen_foo_ops.h
// to:
// CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const std::string& path) {
string ToGuard(const string& path) {
string guard;
guard.reserve(path.size() + 1); // + 1 -> trailing _
for (const char c : path) {
@ -360,7 +361,12 @@ bool HasOptionalAttrs(
}
struct OpInfo {
explicit OpInfo(const OpDef& op_def);
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
// interface_op_def: The OpDef used in the interface in the generated
// code, with possibly overridden names and defaults.
explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
const std::vector<string>& aliases);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
bool include_attr) const;
@ -378,11 +384,15 @@ struct OpInfo {
bool has_optional_attrs;
string comment;
const OpDef& graph_op_def;
const OpDef& op_def;
const std::vector<string>& aliases;
std::unordered_map<string, string> inferred_input_attrs;
};
OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
const std::vector<string>& a)
: graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
op_name = op_def.name();
InferOpAttributes(op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
@ -443,7 +453,6 @@ OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
strings::StrAppend(&comment, " ", attr.description(), "\n");
}
}
comment = MakeComment(comment, "");
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
@ -453,13 +462,55 @@ OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
output_names.push_back(AvoidCPPKeywords(arg.name()));
is_list_output.push_back(is_list);
}
strings::StrAppend(&comment, "\nReturns:\n");
if (op_def.output_arg_size() == 0) { // No outputs.
strings::StrAppend(&comment, "* the created `Operation`\n");
} else if (op_def.output_arg_size() == 1) { // One output
if (is_list_output[0]) {
strings::StrAppend(&comment, "* `OutputList`: ");
} else {
strings::StrAppend(&comment, "* `Output`: ");
}
if (op_def.output_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
" tensor.\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
}
} else { // Multiple outputs.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
if (is_list_output[i]) {
strings::StrAppend(&comment, "* `OutputList`");
} else {
strings::StrAppend(&comment, "* `Output`");
}
strings::StrAppend(&comment, " ", output_names[i]);
if (op_def.output_arg(i).description().empty()) {
strings::StrAppend(&comment, "\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
"\n");
}
}
}
if (!aliases.empty()) {
strings::StrAppend(&comment, "\nAliases:\n");
for (const auto& alias : aliases) {
strings::StrAppend(&comment, "* ", alias, "\n");
}
}
comment = MakeComment(comment, "");
}
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
string attrs_comment = strings::StrCat("Optional attribute setters for ",
op_def.name(), " :\n\n");
string attrs_comment =
strings::StrCat("Optional attribute setters for ", op_name, " :\n\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
@ -603,7 +654,13 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
";\n");
}
strings::StrAppend(&class_decl, "};\n\n");
strings::StrAppend(&class_decl, "};\n");
if (!aliases.empty()) {
for (const auto& alias : aliases) {
strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n");
}
}
strings::StrAppend(&class_decl, "\n");
TF_CHECK_OK(h->Append(class_decl));
}
@ -642,7 +699,7 @@ void OpInfo::GetOutput(string* out) const {
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", op_def.output_arg(i).name(), "\"]");
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ",
arg_range, ".second; ++i)\n");
@ -673,17 +730,18 @@ string OpInfo::GetConstructorBody() const {
}
strings::StrAppend(&body, " ::tensorflow::Node* ret;\n");
strings::StrAppend(&body, " const auto unique_name = ", scope_str,
".GetUniqueNameForOp(\"", op_def.name(), "\");\n");
strings::StrAppend(&body, " const auto unique_name = ", scope_str,
".GetUniqueNameForOp(\"", op_name, "\");\n");
strings::StrAppend(
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
op_def.name(), "\")\n");
graph_op_def.name(), "\")\n");
const string spaces = " ";
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
}
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& graph_attr(graph_op_def.attr(i));
const auto& attr(op_def.attr(i));
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
continue;
@ -691,7 +749,7 @@ string OpInfo::GetConstructorBody() const {
const string attr_name = attr.has_default_value()
? strings::StrCat("attrs.", attr.name(), "_")
: AvoidCPPKeywords(attr.name());
strings::StrAppend(&body, spaces, ".Attr(\"", attr.name(), "\", ",
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
attr_name, ")\n");
}
strings::StrAppend(&body, " ;\n");
@ -736,23 +794,17 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
TF_CHECK_OK(cc->Append(class_def));
}
void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
OpInfo op_info(op_def);
void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
const std::vector<string>& aliases, WritableFile* h,
WritableFile* cc) {
OpInfo op_info(graph_op_def, interface_op_def, aliases);
op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
}
} // namespace
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
const std::string& dot_cc_fname) {
Env* env = Env::Default();
std::unique_ptr<WritableFile> h = nullptr;
std::unique_ptr<WritableFile> cc = nullptr;
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
WritableFile* cc, string* op_header_guard) {
const string header =
R"header(// This file is MACHINE GENERATED! Do not edit.
@ -765,18 +817,22 @@ void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
)header";
// TODO(keveman): Make namespaces configurable.
const string namespace_begin = R"namespace(
const string namespace_begin = internal ? R"namespace(
namespace tensorflow {
namespace ops {
namespace internal {
// NOTE: This namespace has internal TensorFlow details that
// are not part of TensorFlow's public API.
)namespace"
: R"namespace(
namespace tensorflow {
namespace ops {
)namespace";
const string footer = R"footer(} // namespace ops
} // namespace tensorflow
)footer";
const string op_header = GetPath(dot_h_fname);
const string op_header_guard = ToGuard(op_header);
*op_header_guard = ToGuard(op_header);
const string cc_header = strings::StrCat(
R"include(// This file is MACHINE GENERATED! Do not edit.
@ -788,27 +844,25 @@ namespace ops {
TF_CHECK_OK(h->Append(
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
"#ifndef ",
op_header_guard,
*op_header_guard,
"\n"
"#define ",
op_header_guard, "\n\n")));
*op_header_guard, "\n\n")));
TF_CHECK_OK(h->Append(header));
TF_CHECK_OK(h->Append(namespace_begin));
TF_CHECK_OK(cc->Append(cc_header));
}
for (const auto& op_def : ops.op()) {
// Skip deprecated ops.
// TODO(josh11b): If needed, can put them into a "deprecated" namespace
// instead of skipping.
if (op_def.has_deprecation() &&
op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
continue;
}
// We use a hand-written wrapper for "Const", since the generated
// code depends on it.
if (op_def.name() == "Const") continue;
WriteCCOp(op_def, h.get(), cc.get());
}
void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
const string& op_header_guard) {
const string footer = internal ? R"footer(} // namespace internal
} // namespace ops
} // namespace tensorflow
)footer"
:
R"footer(} // namespace ops
} // namespace tensorflow
)footer";
TF_CHECK_OK(h->Append(footer));
TF_CHECK_OK(
@ -819,4 +873,82 @@ namespace ops {
TF_CHECK_OK(h->Close());
}
string MakeInternal(const string& fname) {
auto dot_pos = fname.rfind('.');
if (dot_pos == string::npos) {
return strings::StrCat(fname, "_internal");
} else {
return strings::StrCat(fname.substr(0, dot_pos), "_internal",
fname.substr(dot_pos));
}
}
} // namespace
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames) {
Env* env = Env::Default();
// Load the override map.
OpGenOverrideMap override_map;
if (!overrides_fnames.empty()) {
override_map.LoadFileList(env, overrides_fnames);
}
// Write the initial boilerplate to the .h and .cc files.
std::unique_ptr<WritableFile> h = nullptr;
std::unique_ptr<WritableFile> cc = nullptr;
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
string op_header_guard;
StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
// Create the internal versions of these files for the hidden ops.
std::unique_ptr<WritableFile> internal_h = nullptr;
std::unique_ptr<WritableFile> internal_cc = nullptr;
const string internal_dot_h_fname = MakeInternal(dot_h_fname);
TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h));
TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
string internal_op_header_guard;
StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
internal_cc.get(), &internal_op_header_guard);
for (const auto& graph_op_def : ops.op()) {
// Skip deprecated ops.
// TODO(josh11b): If needed, can put them into a "deprecated" namespace
// instead of skipping.
if (graph_op_def.has_deprecation() &&
graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
continue;
}
// We use a hand-written wrapper for "Const", since the generated
// code depends on it.
if (graph_op_def.name() == "Const") continue;
// Incorporate overrides from override_map.
OpDef interface_op_def = graph_op_def;
const OpGenOverride* op_override =
override_map.ApplyOverride(&interface_op_def);
std::vector<string> aliases;
if (op_override) {
if (op_override->skip()) continue;
aliases.assign(op_override->alias().begin(), op_override->alias().end());
if (op_override->hide()) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
}
// This isn't a hidden op, write it to the main files.
WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
}
FinishFiles(false, h.get(), cc.get(), op_header_guard);
FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(),
internal_op_header_guard);
}
} // namespace tensorflow

View File

@ -17,12 +17,13 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
/// Result is written to files dot_h and dot_cc.
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
const std::string& dot_cc_fname);
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames);
} // namespace tensorflow

View File

@ -24,10 +24,10 @@ namespace tensorflow {
namespace {
void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
bool include_internal) {
const std::string& overrides_fnames, bool include_internal) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
WriteCCOps(ops, dot_h, dot_cc);
WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
}
} // namespace
@ -35,15 +35,18 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc != 4) {
if (argc != 5) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
"Usage: %s out.h out.cc include_internal\n"
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
}
bool include_internal = tensorflow::StringPiece("1") == argv[3];
tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal);
bool include_internal = tensorflow::StringPiece("1") == argv[4];
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
return 0;
}

View File

@ -68,7 +68,7 @@ TEST(CCOpTest, Attrs) {
TEST(CCOpTest, SplitConcat) {
Scope root = Scope::NewRootScope();
Split p(root, 0, {{1}, {2}}, 2);
auto c = Concat(root, 0, {p[0], p[1]});
auto c = Concat(root, {p[0], p[1]}, 0);
TF_EXPECT_OK(root.status());
Tensor out;
test::GetTensor(root, c, &out);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
@ -92,7 +93,7 @@ Status SplitGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(Concat(scope, op.input(0), grad_inputs));
grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);
@ -219,7 +220,7 @@ Status ReverseGrad(const Scope& scope, const Operation& op,
grad_outputs->push_back(NoGradient());
return scope.status();
}
REGISTER_GRADIENT_OP("Reverse", ReverseGrad);
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);
Status ScatterNdGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
@ -319,8 +320,8 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
string mode;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
grad_outputs->push_back(
tensorflow::ops::MirrorPadGrad(scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(NoGradient());
return scope.status();
}

View File

@ -17,12 +17,14 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
using ops::internal::MirrorPadGrad;
namespace {
@ -207,8 +209,7 @@ TEST_F(ArrayGradTest, ReverseSequenceGrad) {
TEST_F(ArrayGradTest, ReverseGrad) {
TensorShape shape({5, 2, 5});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
auto reverse_dims = Const(scope_, {true, false, true});
auto y = Reverse(scope_, x, reverse_dims);
auto y = Reverse(scope_, x, {0, 2});
RunTest(x, shape, y, shape);
}

View File

@ -0,0 +1,42 @@
# array_ops
op { name: "BroadcastArgs" rename_to: "BroadcastDynamicShape" }
op { name: "ConcatOffset" skip: true } # Maybe should just be hidden?
op { name: "Concat" skip: true }
op { name: "ConcatV2" rename_to: "Concat" }
op { name: "MirrorPadGrad" hide: true }
op { name: "Reverse" skip: true }
op { name: "ReverseV2" rename_to: "Reverse" }
# candidate_sampling_ops
# control_flow_ops
# ctc_ops
# data_flow_ops
op { name: "FakeQueue" skip: true }
# functional_ops
# image_ops
# io_ops
# linalg_ops
# logging_ops
# math_ops
op { name: "All" alias: "ReduceAll" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Any" alias: "ReduceAny" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Max" alias: "ReduceMax" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Mean" alias: "ReduceMean" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Min" alias: "ReduceMin" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } }
op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } }
# nn_ops
op { name: "TopKV2" rename_to: "TopK" alias: "TopKV2" } # TODO(josh11b): delete "TopKV2" alias once users updated
# parsing_ops
# random_ops
# script_ops
# sdca_ops
# state_ops
# sparse_ops
# string_ops
# user_ops
# training_ops
# word2vec deprecated ops

View File

@ -454,6 +454,7 @@ $(wildcard tensorflow/core/*/*/*testutil*) \
$(wildcard tensorflow/core/*/*/*testlib*) \
$(wildcard tensorflow/core/*/*/*main.cc) \
$(wildcard tensorflow/core/debug/*.cc) \
$(wildcard tensorflow/core/framework/op_gen_lib.cc) \
$(wildcard tensorflow/core/graph/dot.*) \
$(wildcard tensorflow/core/lib/gif/*) \
$(wildcard tensorflow/core/lib/io/zlib*) \

View File

@ -328,7 +328,6 @@ tf_cuda_library(
"framework/op.h",
"framework/op_def_builder.h",
"framework/op_def_util.h",
"framework/op_gen_lib.h",
"framework/op_kernel.h",
"framework/partial_tensor_shape.h",
"framework/queue_interface.h",
@ -384,6 +383,25 @@ tf_cuda_library(
deps = [":framework_internal"],
)
tf_proto_library_cc(
name = "op_gen_overrides_proto",
srcs = ["framework/op_gen_overrides.proto"],
cc_api_version = 2,
protodeps = [":protos_all"],
visibility = ["//visibility:public"],
)
cc_library(
name = "op_gen_lib",
srcs = ["framework/op_gen_lib.cc"],
hdrs = ["framework/op_gen_lib.h"],
visibility = ["//visibility:public"],
deps = [
":lib",
":op_gen_overrides_proto_cc",
],
)
cc_library(
name = "framework_lite",
srcs = tf_additional_minimal_lib_srcs(),
@ -761,6 +779,7 @@ filegroup(
"**/*testlib*",
"**/*main.cc",
"debug/**/*",
"framework/op_gen_*",
"graph/dot.*",
"lib/jpeg/**/*",
"lib/png/**/*",
@ -1250,6 +1269,7 @@ tf_cuda_library(
"util/reporter.h",
"util/reporter.cc",
"framework/fake_input.*",
"framework/op_gen_lib.*",
"util/memmapped_file_system.*",
"util/memmapped_file_system_writer.*",
"util/version_info.cc",

View File

@ -397,18 +397,22 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
}
}
*result = target_context->MakeShape(dims);
} else if (src_op == "Concat") {
} else if (src_op == "Concat" || src_op == "ConcatV2") {
*result = target_context->Scalar();
// For Concat, input 0 is concat dim; for V2 it is the last input.
const int concat_dim =
src_op == "Concat" ? 0 : src_context->num_inputs() - 1;
// Concat is concatenating its input shape vectors.
// input 0 is ignored as it is the concat dim and will always be 0.
for (int i = 1; i < src_context->num_inputs(); ++i) {
for (int i = 0; i < src_context->num_inputs(); ++i) {
// Concat dim is ignored (and will always be a scalar).
if (i == concat_dim) continue;
ShapeHandle sub_result;
TF_RETURN_IF_ERROR(ConstantPartialShape(target_context, input_edge->src(),
i, &sub_result));
if (!target_context->RankKnown(sub_result)) {
// Failed to evaluate. Treat the output as completely unknown.
// TODO(cwhipkey): we could rely on all inputs being the same size, so
// figure that size out and append the right number of unknown dims.
// TODO(cwhipkey): we could rely on all inputs being the same rank, so
// figure that rank out and append the right number of unknown dims.
*result = target_context->UnknownShape();
return Status::OK();
}

View File

@ -642,7 +642,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
const_input,
}; // clang-format on
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
auto concat = ops::Concat(root, concat_inputs, concat_dim);
TF_ASSERT_OK(root.status());
Node* result;
@ -684,7 +684,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
ops::Shape(root, Output(unknown)),
}; // clang-format on
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
auto concat = ops::Concat(root, concat_inputs, concat_dim);
TF_ASSERT_OK(root.status());
Node* result;
@ -726,7 +726,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
const_input,
}; // clang-format on
auto concat_dim = ops::Const(root, 0);
auto concat = ops::Concat(root, concat_dim, concat_inputs);
auto concat = ops::Concat(root, concat_inputs, concat_dim);
TF_ASSERT_OK(root.status());
Node* result;

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/framework/op_gen_lib.h"
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@ -67,4 +70,136 @@ bool ConsumeEquals(StringPiece* description) {
return false;
}
Status OpGenOverrideMap::LoadFileList(Env* env, const string& filenames) {
std::vector<string> v = str_util::Split(filenames, ",");
for (const string& f : v) {
TF_RETURN_IF_ERROR(LoadFile(env, f));
}
return Status::OK();
}
Status OpGenOverrideMap::LoadFile(Env* env, const string& filename) {
if (filename.empty()) return Status::OK();
string contents;
TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
OpGenOverrides all;
protobuf::TextFormat::ParseFromString(contents, &all);
for (const auto& one : all.op()) {
map_[one.name()] = one;
}
return Status::OK();
}
static void StringReplace(const string& from, const string& to, string* s) {
std::vector<string> rest = str_util::Split(*s, from);
*s = str_util::Join(rest, to.c_str());
}
static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
const string from_quoted = strings::StrCat("`", from, "`");
const string to_quoted = strings::StrCat("`", to, "`");
for (int i = 0; i < op_def->input_arg_size(); ++i) {
if (!op_def->input_arg(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
op_def->mutable_input_arg(i)->mutable_description());
}
}
for (int i = 0; i < op_def->output_arg_size(); ++i) {
if (!op_def->output_arg(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
op_def->mutable_output_arg(i)->mutable_description());
}
}
for (int i = 0; i < op_def->attr_size(); ++i) {
if (!op_def->attr(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
op_def->mutable_attr(i)->mutable_description());
}
}
if (!op_def->summary().empty()) {
StringReplace(from_quoted, to_quoted, op_def->mutable_summary());
}
if (!op_def->description().empty()) {
StringReplace(from_quoted, to_quoted, op_def->mutable_summary());
}
}
const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
// Look up
const auto iter = map_.find(op_def->name());
if (iter == map_.end()) return nullptr;
const OpGenOverride& proto = iter->second;
// Apply overrides from `proto`.
if (!proto.rename_to().empty()) {
op_def->set_name(proto.rename_to());
RenameInDocs(proto.name(), proto.rename_to(), op_def);
}
for (const auto& attr_default : proto.attr_default()) {
bool found = false;
for (int i = 0; i < op_def->attr_size(); ++i) {
if (op_def->attr(i).name() == attr_default.name()) {
*op_def->mutable_attr(i)->mutable_default_value() =
attr_default.value();
found = true;
break;
}
}
if (!found) {
LOG(WARNING) << proto.name() << " can't find attr " << attr_default.name()
<< " to override default";
}
}
for (const auto& attr_rename : proto.attr_rename()) {
bool found = false;
for (int i = 0; i < op_def->attr_size(); ++i) {
if (op_def->attr(i).name() == attr_rename.from()) {
*op_def->mutable_attr(i)->mutable_name() = attr_rename.to();
found = true;
break;
}
}
if (found) {
RenameInDocs(attr_rename.from(), attr_rename.to(), op_def);
} else {
LOG(WARNING) << proto.name() << " can't find attr " << attr_rename.from()
<< " to rename";
}
}
for (const auto& input_rename : proto.input_rename()) {
bool found = false;
for (int i = 0; i < op_def->input_arg_size(); ++i) {
if (op_def->input_arg(i).name() == input_rename.from()) {
*op_def->mutable_input_arg(i)->mutable_name() = input_rename.to();
found = true;
break;
}
}
if (found) {
RenameInDocs(input_rename.from(), input_rename.to(), op_def);
} else {
LOG(WARNING) << proto.name() << " can't find input "
<< input_rename.from() << " to rename";
}
}
for (const auto& output_rename : proto.output_rename()) {
bool found = false;
for (int i = 0; i < op_def->output_arg_size(); ++i) {
if (op_def->output_arg(i).name() == output_rename.from()) {
*op_def->mutable_output_arg(i)->mutable_name() = output_rename.to();
found = true;
break;
}
}
if (found) {
RenameInDocs(output_rename.from(), output_rename.to(), op_def);
} else {
LOG(WARNING) << proto.name() << " can't find output "
<< output_rename.from() << " to rename";
}
}
return &proto;
}
} // namespace tensorflow

View File

@ -17,7 +17,12 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
#include <string>
#include <unordered_map>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
@ -34,6 +39,31 @@ string WordWrap(StringPiece prefix, StringPiece str, int width);
// returns false.
bool ConsumeEquals(StringPiece* description);
// Takes a list of files with OpGenOverrides text protos, and allows you to
// look up the specific override for any given op.
class OpGenOverrideMap {
public:
// `filenames` is a comma-separated list of file names. If an op
// is mentioned in more than one file, the last one takes priority.
Status LoadFileList(Env* env, const string& filenames);
// Load a single file. If more than one file is loaded, later ones
// take priority for any ops in common.
Status LoadFile(Env* env, const string& filename);
// Look up the override for `*op_def` from the loaded files, and
// mutate `*op_def` to reflect the requested changes. Does not apply
// 'skip', 'hide', or 'alias' overrides. Caller has to deal with
// those since they can't be simulated by mutating `*op_def`.
// Returns nullptr if op is not in any loaded file. Otherwise, the
// pointer must not be referenced beyond the lifetime of *this or
// the next file load.
const OpGenOverride* ApplyOverride(OpDef* op_def) const;
private:
std::unordered_map<string, OpGenOverride> map_;
};
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_

View File

@ -0,0 +1,67 @@
// Defines the text format for adding per-op overrides for client
// language op code generators.
syntax = "proto3";
package tensorflow;
import "tensorflow/core/framework/attr_value.proto";
// Used to override the default API & behavior in the generated code
// for client languages, from what you would get from the OpDef alone.
// This is so we can evolve the API while remaining backwards
// compatible when interpretting old graphs. Overrides go in an
// "op_gen_overrides.pbtxt" file with a text-format OpGenOverrides
// message. Right now these only apply to the C++ API.
// TODO(josh11b): In the future there will be a common set of overrides
// and per-client-language overrides.
//
// WARNING: Be *very* careful using these features -- these overrides
// can change the semantics of existing code. These changes may need
// to wait until a major release of TensorFlow to avoid breaking our
// compatibility promises.
message OpGenOverride {
// Name of the op to apply overrides to.
string name = 1;
// Do not include this op in the generated API.
// If `skip` is true, all other overrides are ignored for this op.
bool skip = 2;
// Hide this op by putting it into an internal namespace (or whatever
// is appropriate in the target language).
bool hide = 3;
// Use a different name in the API than the op's name. Note that
// the op's name in `backticks` will also be replaced in the docs.
string rename_to = 4;
// Create *additional* API endpoints with different names (contrast
// with rename_to, which affects the original name).
repeated string alias = 5;
// Map the name of an attr to a new default value to use. This
// default will be used when creating new graphs, as opposed to the
// default in the OpDef, which will be used when interpreting old
// GraphDefs. If this attr is also renamed (using attr_rename
// below), use the original name of the attr.
message AttrDefault {
string name = 1;
AttrValue value = 2;
}
repeated AttrDefault attr_default = 6;
// Change the name used to access attrs/inputs/outputs in the API
// from what is used in the GraphDef. Note that these names in
// `backticks` will also be replaced in the docs.
message Rename {
string from = 1;
string to = 2;
}
repeated Rename attr_rename = 7;
repeated Rename input_rename = 8;
repeated Rename output_rename = 9;
}
message OpGenOverrides {
repeated OpGenOverride op = 1;
}

View File

@ -1860,6 +1860,7 @@ values: The `k` largest elements along each last dimensional slice.
indices: The indices of `values` within the last dimension of `input`.
)doc");
// This is the same as `TopK`, but takes `k` as in input rather than an attr.
REGISTER_OP("TopKV2")
.Input("input: T")
.Input("k: int32")
@ -1882,8 +1883,6 @@ row (resp. vector along the last dimension). Thus,
If two elements are equal, the lower-index element appears first.
This is the same as `TopK`, but takes `k` as in input rather than an attr.
input: 1-D or higher with last dimension at least `k`.
k: 0-D. Number of top elements to look for along the last dimension (along each
row for matrices).

View File

@ -163,7 +163,7 @@ Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
string output_name = "top_k";
TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensors.
tensorflow::GraphDef graph;

View File

@ -184,7 +184,7 @@ Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels,
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
string output_name = "top_k";
TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensors.
tensorflow::GraphDef graph;

View File

@ -315,6 +315,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_cc",
],
alwayslink = 1,

View File

@ -143,6 +143,7 @@ def tf_gen_op_libs(op_lib_names, deps=None):
def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
op_gen="//tensorflow/cc:cc_op_gen_main",
deps=None,
override_file=None,
include_internal_ops=0):
# Construct an op generator binary for these ops.
tool = out_ops_file + "_gen_cc"
@ -156,12 +157,21 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
deps = [op_gen] + deps
)
if override_file == None:
srcs = []
override_arg = ","
else:
srcs = [override_file]
override_arg = "$(location " + override_file + ")"
native.genrule(
name=name + "_genrule",
outs=[out_ops_file + ".h", out_ops_file + ".cc"],
outs=[out_ops_file + ".h", out_ops_file + ".cc",
out_ops_file + "_internal.h", out_ops_file + "_internal.cc"],
srcs=srcs,
tools=[":" + tool],
cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
"$(location :" + out_ops_file + ".cc) " + str(include_internal_ops)))
"$(location :" + out_ops_file + ".cc) " + override_arg + " " +
str(include_internal_ops)))
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate individual C++ .cc and .h
@ -181,6 +191,15 @@ def tf_gen_op_wrapper_cc(name, out_ops_file, pkg="",
# hdrs = [ "ops/array_ops.h",
# "ops/math_ops.h" ],
# deps = [ ... ])
#
# Plus a private library for the "hidden" ops.
# cc_library(name = "tf_ops_lib_internal",
# srcs = [ "ops/array_ops_internal.cc",
# "ops/math_ops_internal.cc" ],
# hdrs = [ "ops/array_ops_internal.h",
# "ops/math_ops_internal.h" ],
# deps = [ ... ])
# TODO(josh11b): Cleaner approach for hidden ops.
def tf_gen_op_wrappers_cc(name,
op_lib_names=[],
other_srcs=[],
@ -192,16 +211,21 @@ def tf_gen_op_wrappers_cc(name,
"//tensorflow/cc:const_op",
],
op_gen="//tensorflow/cc:cc_op_gen_main",
override_file=None,
include_internal_ops=0,
visibility=None):
subsrcs = other_srcs
subhdrs = other_hdrs
internalsrcs = []
internalhdrs = []
for n in op_lib_names:
tf_gen_op_wrapper_cc(
n, "ops/" + n, pkg=pkg, op_gen=op_gen,
n, "ops/" + n, pkg=pkg, op_gen=op_gen, override_file=override_file,
include_internal_ops=include_internal_ops)
subsrcs += ["ops/" + n + ".cc"]
subhdrs += ["ops/" + n + ".h"]
internalsrcs += ["ops/" + n + "_internal.cc"]
internalhdrs += ["ops/" + n + "_internal.h"]
native.cc_library(name=name,
srcs=subsrcs,
@ -217,6 +241,20 @@ def tf_gen_op_wrappers_cc(name,
copts=tf_copts(),
alwayslink=1,
visibility=visibility)
native.cc_library(name=name + "_internal",
srcs=internalsrcs,
hdrs=internalhdrs,
deps=deps + if_not_android([
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
]),
copts=tf_copts(),
alwayslink=1,
visibility=["//visibility:private"])
# Invoke this rule in .../tensorflow/python to build the wrapper library.
def tf_gen_op_wrapper_py(name, out=None, hidden=None, visibility=None, deps=[],

View File

@ -296,7 +296,7 @@ class QuantizeNodesTest : public ::testing::Test {
Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor));
Output concat_op =
Concat(root.WithOpName("concat_op"), shape_op, {a_op, b_op});
Concat(root.WithOpName("concat_op"), {a_op, b_op}, shape_op);
GraphDef float_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));