Improvements to the C++ graph building API.

TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Author: Manjunath Kudlur, 2016-07-15 14:28:59 -08:00 (committed by TensorFlower Gardener)
parent 194efde518
commit 25ac3dabfa
28 changed files with 2716 additions and 785 deletions

RELEASE.md

@@ -4,6 +4,9 @@
* Connectionist Temporal Classification ops are now "official" (see, e.g.,
`tf.nn.ctc_loss`)
* Preliminary graph-construction C API, for use by language bindings.
* Major revision to the graph-construction C++ API. A new scoping mechanism
makes op naming, specifying control dependencies, etc. more consistent, and
C++ values can be used directly as operands, making op construction more
concise (see the usage sketch after this diff).
## Breaking Changes to the API
* The `New*File()` functions in `env.h` now return files through `std::unique_ptr` out parameters
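A minimal sketch of the revised C++ API in use, assembled from the tests added later in this change (the values are illustrative, not part of the release notes):

// Sketch: build a small graph with the new Scope-based API.
Scope root = Scope::NewRootScope();
auto c = Const(root, {{1, 1}});                        // C++ values as operands
auto m = MatMul(root.WithOpName("m"), c, {{41}, {1}}); // consistent op naming
TF_CHECK_OK(root.status());                            // errors accumulate on the Scope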

tensorflow/cc/BUILD

@@ -2,29 +2,77 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
package(default_visibility = ["//visibility:public"])
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrappers_cc")
cc_library(
name = "cc_op_gen_main",
srcs = [
"ops/cc_op_gen.cc",
"ops/cc_op_gen_main.cc",
],
hdrs = ["ops/cc_op_gen.h"],
copts = tf_copts(),
)
cc_library(
name = "ops",
srcs = ["framework/ops.cc"],
hdrs = ["framework/ops.h"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "scope",
srcs = ["framework/scope.cc"],
hdrs = ["framework/scope.h"],
deps = [
":ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "framework/scope_test.cc",
deps = [
":ops",
":scope",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "const_op",
srcs = ["ops/const_op.cc"],
hdrs = ["ops/const_op.h"],
deps = [
":ops",
":scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
],
)
tf_cc_test(
name = "ops/const_op_test.cc",
deps = [
":const_op",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
# Generates a library that contains C++ wrappers for ops.
tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [
@@ -41,7 +89,6 @@ tf_gen_op_wrappers_cc(
"no_op",
"parsing_ops",
"random_ops",
"sendrecv_ops",
"sparse_ops",
"state_ops",
"string_ops",
@@ -52,12 +99,63 @@ tf_gen_op_wrappers_cc(
"ops/const_op.h",
"ops/standard_ops.h",
],
other_srcs = [
"ops/const_op.cc",
] + glob(["ops/*_grad.cc"]),
pkg = "//tensorflow/core",
)
tf_cc_test(
name = "framework/cc_ops_test.cc",
deps = [
":cc_ops",
":test_op",
":test_op_op_lib",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_gen_op_wrappers_cc(
name = "sendrecv_ops",
op_lib_names = [
"sendrecv_ops",
],
pkg = "//tensorflow/core",
)
cc_library(
name = "cc_op_gen_main",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_main.cc",
],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "test_op_op_lib",
srcs = ["framework/test_op.cc"],
linkstatic = 1,
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
tf_gen_op_wrappers_cc(
name = "test_op",
op_lib_names = [
"test_op",
],
)
cc_binary(
name = "tutorials_example_trainer",
srcs = ["tutorials/example_trainer.cc"],
@@ -69,6 +167,10 @@ cc_binary(
deps = [
":cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)

tensorflow/cc/framework/cc_op_gen.cc

@@ -0,0 +1,798 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
const int kRightMargin = 79;
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
string GetPath(const std::string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
string result = dot_h_fname;
if (pos != string::npos) {
// The "- 1" accounts for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);
pos = result.find("/");
if (pos != string::npos) {
result = result.substr(pos + 1);
}
}
return result;
}
// Converts:
// cc/ops/gen_foo_ops.h
// to:
// CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const std::string& path) {
string guard;
guard.reserve(path.size() + 1); // + 1 -> trailing _
for (const char c : path) {
if (c >= 'A' && c <= 'Z') {
guard += c;
} else if (c >= 'a' && c <= 'z') {
guard += c + 'A' - 'a';
} else {
guard += '_';
}
}
guard += '_';
return guard;
}
// Change:     Into:
//   ABC       // ABC
//             //
//   DEF       // DEF
string MakeComment(StringPiece text, StringPiece indent) {
string ret;
while (!text.empty()) {
int last_non_space = -1;
int newline;
for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
if (text[newline] == '\n') break;
if (text[newline] != ' ') last_non_space = newline;
}
if (last_non_space == -1) {
strings::StrAppend(&ret, indent, "//\n");
} else {
strings::StrAppend(&ret, indent, "// ",
text.substr(0, last_non_space + 1), "\n");
}
text.remove_prefix(newline + 1);
}
return ret;
}
string PrintString(const string& str) {
return strings::StrCat("\"", str_util::CEscape(str), "\"");
}
string PrintTensorShape(const TensorShape& shape) {
string ret = "{";
for (int d = 0; d < shape.dims(); ++d) {
if (d > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, shape.dim_size(d));
}
strings::StrAppend(&ret, "}");
return ret;
}
template <typename T>
string PrintArray(int64 num_elts, const T* array) {
string ret;
for (int64 i = 0; i < num_elts; ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, array[i]);
}
return ret;
}
string PrintTensor(const TensorProto& tensor_proto) {
Tensor t(tensor_proto.dtype());
CHECK(t.FromProto(tensor_proto));
const int64 num_elts = t.NumElements();
switch (t.dtype()) {
case DT_FLOAT:
return PrintArray(num_elts, t.flat<float>().data());
case DT_DOUBLE:
return PrintArray(num_elts, t.flat<double>().data());
case DT_INT32:
return PrintArray(num_elts, t.flat<int32>().data());
case DT_UINT8:
case DT_QUINT8:
return PrintArray(num_elts, t.flat<uint8>().data());
case DT_UINT16:
case DT_QUINT16:
return PrintArray(num_elts, t.flat<uint16>().data());
case DT_INT16:
case DT_QINT16:
return PrintArray(num_elts, t.flat<int16>().data());
case DT_INT8:
case DT_QINT8:
return PrintArray(num_elts, t.flat<int8>().data());
case DT_INT64:
return PrintArray(num_elts, t.flat<int64>().data());
case DT_BOOL:
return PrintArray(num_elts, t.flat<bool>().data());
case DT_STRING: {
string ret;
for (int64 i = 0; i < num_elts; ++i) {
if (i > 0) strings::StrAppend(&ret, " ");
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
}
return ret;
}
default: {
LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype());
return string();
}
}
}
string PrintAttrValue(string op, const AttrValue& attr_value) {
switch (attr_value.value_case()) {
case AttrValue::kS:
return PrintString(attr_value.s());
case AttrValue::kI:
return strings::StrCat(attr_value.i());
case AttrValue::kF:
return strings::StrCat(attr_value.f());
case AttrValue::kB:
return attr_value.b() ? "true" : "false";
case AttrValue::kType:
return EnumName_DataType(attr_value.type());
case AttrValue::kShape:
return PrintTensorShape(TensorShape(attr_value.shape()));
case AttrValue::kTensor:
return strings::StrCat(
"Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ",
PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())),
").AsTensorProto()");
case AttrValue::kList: {
string ret = "{";
if (attr_value.list().s_size() > 0) {
for (int i = 0; i < attr_value.list().s_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, PrintString(attr_value.list().s(i)));
}
} else if (attr_value.list().i_size() > 0) {
for (int i = 0; i < attr_value.list().i_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, attr_value.list().i(i));
}
} else if (attr_value.list().f_size() > 0) {
for (int i = 0; i < attr_value.list().f_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, attr_value.list().f(i));
}
} else if (attr_value.list().b_size() > 0) {
for (int i = 0; i < attr_value.list().b_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
}
} else if (attr_value.list().type_size() > 0) {
for (int i = 0; i < attr_value.list().type_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(&ret,
EnumName_DataType(attr_value.list().type(i)));
}
} else if (attr_value.list().shape_size() > 0) {
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
if (i > 0) strings::StrAppend(&ret, ", ");
strings::StrAppend(
&ret, PrintTensorShape(TensorShape(attr_value.list().shape(i))));
}
}
strings::StrAppend(&ret, "}");
return ret;
}
default:
LOG(FATAL) << "Unsupported Attr type: " << op << " "
<< attr_value.value_case();
}
return "<Unknown AttrValue type>"; // Prevent missing return warning
}
string ToCamelCase(const string& str) {
string result;
const char joiner = '_';
int i = 0;
bool cap = true;
while (i < str.size()) {
const char c = str[i++];
if (c == joiner) {
cap = true;
} else if (cap) {
result += toupper(c);
cap = false;
} else {
result += c;
}
}
return result;
}
// Returns a <string, bool> pair. The string is the C++ type name to be used for
// attr_type when defining an object of that type. The bool is a flag to
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
StringPiece::Hasher>
attr_type_map{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
{"int", {"int64", false}},
{"list(int)", {"gtl::ArraySlice<int>", true}},
{"float", {"float", false}},
{"list(float)", {"gtl::ArraySlice<float>", true}},
{"bool", {"bool", false}},
{"list(bool)", {"gtl::ArraySlice<bool>", true}},
{"type", {"DataType", false}},
{"list(type)", {"DataTypeSlice", true}},
{"shape", {"TensorShape", false}},
{"list(shape)", {"gtl::ArraySlice<TensorShape>", true}},
{"tensor", {"TensorProto", true}},
{"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
{"func", {"NameAttrList", true}},
};
auto entry = attr_type_map.find(attr_type);
if (entry == attr_type_map.end()) {
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return {"", false};
}
return entry->second;
}
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPiece::Hasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
kCPPReserved{
"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel",
"atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool",
"break", "case", "catch", "char", "char16_t", "char32_t", "class",
"compl", "concept", "const", "const_cast", "constexpr", "continue",
"decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "final",
"float", "for", "friend", "goto", "if", "import", "inline", "int",
"long", "module", "mutable", "namespace", "new", "noexcept", "not",
"not_eq", "nullptr", "operator", "or", "or_eq", "override", "private",
"protected", "public", "register", "reinterpret_cast", "requires",
"return", "short", "signed", "sizeof", "static", "static_assert",
"static_cast", "struct", "switch", "synchronized", "template", "this",
"thread_local", "throw", "true", "try", "typedef", "typeid",
"typename", "union", "unsigned", "using", "virtual", "void",
"volatile", "wchar_t", "while", "xor", "xor_eq",
// The following are not C++ keywords, but names of local variables
// and parameters used in the op constructor. They are treated as
// keywords here so that other parameter names don't conflict with them.
"builder", "node", "ret", "scope", "unique_name",
};
return kCPPReserved.count(name) > 0;
}
string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
return name.ToString();
}
void InferArgAttributes(const OpDef::ArgDef& arg,
std::unordered_map<string, string>* inferred_attrs) {
if (!arg.type_attr().empty()) {
gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name());
} else if (!arg.type_list_attr().empty()) {
gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name());
}
if (!arg.number_attr().empty()) {
gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name());
}
}
void InferOpAttributes(
const OpDef& op_def,
std::unordered_map<string, string>* inferred_input_attrs) {
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
InferArgAttributes(arg, inferred_input_attrs);
}
}
bool ArgIsList(const OpDef::ArgDef& arg) {
return !arg.type_list_attr().empty() || !arg.number_attr().empty();
}
bool HasOptionalAttrs(
const OpDef& op_def,
const std::unordered_map<string, string>& inferred_input_attrs) {
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) ==
inferred_input_attrs.end()) &&
attr.has_default_value()) {
return true;
}
}
return false;
}
struct OpInfo {
explicit OpInfo(const OpDef& op_def);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
bool include_attr) const;
void WriteClassDecl(WritableFile* h) const;
void GetOutput(string* out) const;
string GetConstructorBody() const;
void WriteClassDef(WritableFile* cc) const;
string op_name;
std::vector<string> arg_types;
std::vector<string> arg_names;
std::vector<string> output_types;
std::vector<string> output_names;
std::vector<bool> is_list_output;
bool has_optional_attrs;
string comment;
const OpDef& op_def;
std::unordered_map<string, string> inferred_input_attrs;
};
OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
op_name = op_def.name();
InferOpAttributes(op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");
arg_names.push_back("scope");
if (op_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
if (op_def.has_deprecation()) {
strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ",
op_def.deprecation().version(), ":\n",
op_def.deprecation().explanation(), ".\n");
}
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
}
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
arg_types.push_back(strings::StrCat(
"::tensorflow::ops::", ArgIsList(arg) ? "InputList" : "Input"));
arg_names.push_back(AvoidCPPKeywords(arg.name()));
// TODO(keveman): Include input type information.
StringPiece description = arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
arg.description(), "\n");
}
}
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
// If the attr is going to be inferred or is optional, don't add it as a
// required argument.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
arg_types.push_back(strings::StrCat(use_const ? "const " : "",
attr_type_name, use_const ? "&" : ""));
arg_names.push_back(AvoidCPPKeywords(attr.name()));
if (!attr.description().empty()) {
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(attr.name()), ":\n");
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&comment, " ", attr.description(), "\n");
}
}
comment = MakeComment(comment, "");
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
bool is_list = ArgIsList(arg);
output_types.push_back(strings::StrCat("::tensorflow::ops::",
is_list ? "OutputList" : "Output"));
output_names.push_back(AvoidCPPKeywords(arg.name()));
is_list_output.push_back(is_list);
}
}
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
string attrs_comment = strings::StrCat("Optional attribute setters for ",
op_def.name(), " :\n\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
// If attr will be inferred or it doesn't have a default value, don't
// add it to the struct.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def =
strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
attr_type_name, use_const ? "&" : "");
strings::StrAppend(&attrs_comment, attr_func_def, "): Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
if (!attr.description().empty()) {
// TODO(keveman): Word wrap and indent this to handle multi-line
// description.
strings::StrAppend(&attrs_comment, " ", attr.description(), "\n");
}
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
strings::StrAppend(
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
}
if (struct_fields.empty()) {
return "";
}
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
strings::StrAppend(&struct_decl, " };\n");
return struct_decl;
}
string OpInfo::GetConstructorDecl(StringPiece op_name_prefix,
bool include_attr) const {
const string prefix = strings::StrCat(op_name_prefix, op_name, "(");
string c_decl;
for (int i = 0; i < arg_types.size(); ++i) {
if (i > 0) strings::StrAppend(&c_decl, ", ");
strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]);
}
if (include_attr && has_optional_attrs) {
strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs");
}
strings::StrAppend(&c_decl, ")");
return WordWrap(prefix, c_decl, kRightMargin);
}
void OpInfo::WriteClassDecl(WritableFile* h) const {
string class_decl = comment;
strings::StrAppend(&class_decl, "class ", op_name, " {\n");
strings::StrAppend(&class_decl, " public:\n");
if (has_optional_attrs) {
strings::StrAppend(&class_decl, GetOpAttrStruct());
}
strings::StrAppend(&class_decl, " ",
GetConstructorDecl("", /* include_attr */ false), ";\n");
if (has_optional_attrs) {
strings::StrAppend(&class_decl, " ",
GetConstructorDecl("", /* include_attr */ true), ";\n");
}
if (output_types.empty()) {
// Allow casting this class to Operation.
strings::StrAppend(&class_decl,
" operator ::tensorflow::ops::Operation() const { "
"return operation; }\n");
} else if (output_types.size() == 1) {
if (is_list_output[0]) {
// Write the subscript operator, allowing out[i] for the list-typed
// output.
strings::StrAppend(&class_decl,
" ::tensorflow::ops::Output operator[](size_t index) "
"const { return ",
output_names[0], "[index]; }\n\n");
} else {
// Write type cast functions, allowing casting this class to Input and
// Output.
strings::StrAppend(
&class_decl, " operator ::tensorflow::ops::Output() const { return ",
output_names[0], "; }\n");
strings::StrAppend(
&class_decl, " operator ::tensorflow::ops::Input() const { return ",
output_names[0], "; }\n");
// Write node() to get the Node* directly.
strings::StrAppend(&class_decl,
" ::tensorflow::Node* node() const { return ",
output_names[0], ".node(); }\n");
}
}
// Add the static functions to set optional attrs
if (has_optional_attrs) {
strings::StrAppend(&class_decl, "\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def = strings::StrCat(
camel_case_name, suffix, "(", use_const ? "const " : "",
attr_type_name, use_const ? "&" : "");
strings::StrAppend(&class_decl, " static Attrs ", attr_func_def,
" x) {\n");
strings::StrAppend(&class_decl, " return Attrs().", camel_case_name,
suffix, "(x);\n");
strings::StrAppend(&class_decl, " }\n");
}
}
strings::StrAppend(&class_decl, "\n");
if (output_types.empty()) {
strings::StrAppend(&class_decl, " Operation operation;\n");
}
for (int i = 0; i < output_types.size(); ++i) {
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
";\n");
}
strings::StrAppend(&class_decl, "};\n\n");
TF_CHECK_OK(h->Append(class_decl));
}
void OpInfo::GetOutput(string* out) const {
const string scope_str = arg_names[0];
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
// No outputs.
if (op_def.output_arg_size() == 0) {
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
return;
}
if (op_def.output_arg_size() == 1) {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
" for (int64 i = 0; i < ret->num_outputs(); ++i)\n");
strings::StrAppend(out, " this->", output_names[0],
".push_back(Output(ret, i));\n");
} else {
strings::StrAppend(out, " this->", output_names[0],
" = Output(ret, 0);\n");
}
return;
}
strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
strings::StrAppend(
out,
" ::tensorflow::Status _status_ = "
"::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), "
"nullptr, &_outputs_range);\n");
strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ",
arg_range, ".second; ++i)\n");
strings::StrAppend(out, " this->", output_names[i],
".push_back(Output(ret, i));\n");
} else {
strings::StrAppend(out, " this->", output_names[i], " = Output(ret, ",
arg_range, ".first);\n");
}
}
}
string OpInfo::GetConstructorBody() const {
const string scope_str = arg_names[0];
string body;
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
strings::StrAppend(&body, " ", return_on_error, "\n");
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
}
strings::StrAppend(&body, " ::tensorflow::Node* ret;\n");
strings::StrAppend(&body, " const auto unique_name = ", scope_str,
".GetUniqueNameForOp(\"", op_def.name(), "\");\n");
strings::StrAppend(
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
op_def.name(), "\")\n");
const string spaces = " ";
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
}
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
continue;
}
const string attr_name = attr.has_default_value()
? strings::StrCat("attrs.", attr.name(), "_")
: AvoidCPPKeywords(attr.name());
strings::StrAppend(&body, spaces, ".Attr(\"", attr.name(), "\", ",
attr_name, ")\n");
}
strings::StrAppend(&body, " ;\n");
strings::StrAppend(&body, " ", scope_str, ".UpdateBuilder(&builder);\n");
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
scope_str, ".graph(), &ret));\n");
GetOutput(&body);
return body;
}
void OpInfo::WriteClassDef(WritableFile* cc) const {
string class_def;
strings::StrAppend(&class_def,
GetConstructorDecl(strings::StrCat(op_name, "::"),
/* include_attr */ true),
" {\n");
strings::StrAppend(&class_def, GetConstructorBody());
strings::StrAppend(&class_def, "}\n\n");
if (has_optional_attrs) {
strings::StrAppend(&class_def,
GetConstructorDecl(strings::StrCat(op_name, "::"),
/* include_attr */ false));
strings::StrAppend(&class_def, "\n : ", op_name, "(");
int i = 0;
for (; i < arg_names.size(); ++i) {
if (i > 0) strings::StrAppend(&class_def, ", ");
strings::StrAppend(&class_def, arg_names[i]);
}
if (i > 0) strings::StrAppend(&class_def, ", ");
strings::StrAppend(&class_def, op_name, "::Attrs()");
strings::StrAppend(&class_def, ") {}\n\n");
}
TF_CHECK_OK(cc->Append(class_def));
}
void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
OpInfo op_info(op_def);
op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
}
} // namespace
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
const std::string& dot_cc_fname) {
Env* env = Env::Default();
std::unique_ptr<WritableFile> h = nullptr;
std::unique_ptr<WritableFile> cc = nullptr;
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
const string header =
R"header(// This file is MACHINE GENERATED! Do not edit.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
)header";
// TODO(keveman): Make namespaces configurable.
const string namespace_begin = R"namespace(
namespace tensorflow {
namespace ops {
)namespace";
const string footer = R"footer(} // namespace ops
} // namespace tensorflow
)footer";
const string op_header = GetPath(dot_h_fname);
const string op_header_guard = ToGuard(op_header);
const string cc_header = strings::StrCat(
R"include(// This file is MACHINE GENERATED! Do not edit.
#include "tensorflow/cc/ops/const_op.h"
)include",
"#include \"", op_header, "\"\n", namespace_begin);
TF_CHECK_OK(h->Append(
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
"#ifndef ",
op_header_guard,
"\n"
"#define ",
op_header_guard, "\n\n")));
TF_CHECK_OK(h->Append(header));
TF_CHECK_OK(h->Append(namespace_begin));
TF_CHECK_OK(cc->Append(cc_header));
for (const auto& op_def : ops.op()) {
WriteCCOp(op_def, h.get(), cc.get());
}
TF_CHECK_OK(h->Append(footer));
TF_CHECK_OK(
h->Append(strings::StrCat("\n#endif ", "// ", op_header_guard, "\n")));
TF_CHECK_OK(cc->Append(footer));
TF_CHECK_OK(cc->Close());
TF_CHECK_OK(h->Close());
}
} // namespace tensorflow
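To make the generator's output concrete, here is a hedged sketch of roughly the class that WriteClassDecl and WriteClassDef produce for a single op; the op, attribute, and output names (MatMul, TransposeA, product) are assumed from the ops exercised in the tests below, and this is not verbatim generator output:

// Sketch only: approximate shape of a generated op wrapper.
class MatMul {
 public:
  // Optional attribute setters for MatMul:
  struct Attrs {
    Attrs TransposeA(bool x) {
      Attrs ret = *this;
      ret.transpose_a_ = x;
      return ret;
    }
    bool transpose_a_ = false;
  };
  MatMul(const ::tensorflow::Scope& scope, ::tensorflow::ops::Input a,
         ::tensorflow::ops::Input b);
  MatMul(const ::tensorflow::Scope& scope, ::tensorflow::ops::Input a,
         ::tensorflow::ops::Input b, const MatMul::Attrs& attrs);
  // Single non-list output: implicit casts to Output/Input, plus node().
  operator ::tensorflow::ops::Output() const { return product; }
  operator ::tensorflow::ops::Input() const { return product; }
  ::tensorflow::Node* node() const { return product.node(); }
  static Attrs TransposeA(bool x) { return Attrs().TransposeA(x); }
  ::tensorflow::ops::Output product;
};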

tensorflow/cc/framework/cc_op_gen.h

@@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_
#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
@@ -26,4 +26,4 @@ void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
} // namespace tensorflow
#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_

tensorflow/cc/framework/cc_op_gen_main.cc

@@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/cc_op_gen.h"
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"

tensorflow/cc/framework/cc_ops_test.cc

@@ -0,0 +1,229 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/test_op.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
Output Linear(const Scope& scope, Input x, Input w, Input b) {
auto cop_scopes = scope.GetCompositeOpScopes("linear");
auto m = MatMul(cop_scopes.child, x, w);
return BiasAdd(cop_scopes.last, m, b);
}
void GetTensors(const Scope& scope, OutputList tensors,
std::vector<Tensor>* out) {
SessionOptions options;
std::unique_ptr<Session> session(NewSession(options));
GraphDef def;
scope.graph()->ToGraphDef(&def);
graph::SetDefaultDevice("/cpu:0", &def);
TF_CHECK_OK(session->Create(def));
std::vector<string> names;
for (const auto& t : tensors) {
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
}
TF_CHECK_OK(session->Run({}, names, {}, out));
TF_CHECK_OK(session->Close());
}
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
std::vector<Tensor> outputs;
GetTensors(scope, {tensor}, &outputs);
*out = outputs[0];
}
void GetColocationConstraints(Output tensor, std::vector<string>* constraints) {
constraints->clear();
const auto& attrs = tensor.op().node()->def().attr();
ASSERT_TRUE(attrs.find("_class") != attrs.end());
auto loc = attrs.find("_class")->second;
TF_EXPECT_OK(AttrValueHasType(loc, "list(string)"));
if (loc.value_case() == AttrValue::kList && loc.list().s_size() > 0) {
for (int i = 0; i < loc.list().s_size(); ++i) {
if (loc.list().s(i).find("loc:@") == 0) {
constraints->push_back(loc.list().s(i));
}
}
}
}
} // namespace
TEST(CCOpTest, Basic) {
Scope root = Scope::NewRootScope();
auto c = Const(root, {{1, 1}});
// NOTE: The recommended style for constructing ops is
// auto v = OpConstructor(t0, t1, ...);
// Since the wrappers are implemented as one class per op, the following
// style is also possible:
// PrimitiveOp p(t0, t1, ...);
// It is used here ONLY to ensure that that style is tested.
MatMul m(root, c, {{41}, {1}});
TF_EXPECT_OK(root.status());
Tensor out;
GetTensor(root, m, &out);
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({42}, {1, 1}));
}
TEST(CCOpTest, Attrs) {
Scope root = Scope::NewRootScope();
auto m = MatMul(root, {{1}, {1}}, {{41}, {1}}, MatMul::TransposeA(true));
TF_EXPECT_OK(root.status());
Tensor out;
GetTensor(root, m, &out);
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({42}, {1, 1}));
}
TEST(CCOpTest, SplitConcat) {
Scope root = Scope::NewRootScope();
Split p(root, 0, {{1}, {2}}, 2);
auto c = Concat(root, 0, {p[0], p[1]});
TF_EXPECT_OK(root.status());
Tensor out;
GetTensor(root, c, &out);
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({1, 2}, {2, 1}));
}
TEST(CCOpTest, CompositeOp) {
Scope root = Scope::NewRootScope();
auto l = Linear(root.WithOpName("layer0"), {{10.0f, -3.0f}},
{{.8f, .5f}, {.1f, .6f}}, {-8.0f, 31.0f});
TF_EXPECT_OK(root.status());
EXPECT_EQ(l.node()->name(), "layer0");
Tensor out;
GetTensor(root, l, &out);
test::ExpectClose(out, test::AsTensor<float>({-0.3, 34.2}, {1, 2}));
}
TEST(CCOpTest, MultiOutput) {
Scope root = Scope::NewRootScope();
auto u = Unique(root, {1, 2, 2, 4, 3, 2});
std::vector<Tensor> outputs;
GetTensors(root, {u.y, u.idx}, &outputs);
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({1, 2, 4, 3}));
test::ExpectTensorEqual<int>(outputs[1],
test::AsTensor<int>({0, 1, 1, 2, 3, 1}));
}
TEST(CCOpTest, ExampleTrainer) {
Scope root = Scope::NewRootScope();
// a = [3 2; -1 0]
auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
// x = [1.0; 1.0]
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
// y = a * x
auto y = MatMul(root.WithOpName("y"), a, x);
// y2 = y.^2
auto y2 = Square(root, y);
// y2_sum = sum(y2)
auto y2_sum = Sum(root, y2, 0);
// y_norm = sqrt(y2_sum)
auto y_norm = Sqrt(root, y2_sum);
// y_normalized = y ./ y_norm
auto y_normalized = Div(root.WithOpName("y_normalized"), y, y_norm);
Tensor out;
GetTensor(root, y_normalized, &out);
test::ExpectTensorNear<float>(
out, test::AsTensor<float>({0.98058069, -0.19611613}, {2, 1}), 1e-5);
}
TEST(CCOpTest, ThrowAwayOp) {
Scope root = Scope::NewRootScope();
ThrowAway1(root, 1, 2.3f, 1, 1, 1, ThrowAway1::Builder(42));
ThrowAway2(root, ThrowAway2::ThrowAway2_(3).Scope(1));
TF_EXPECT_OK(root.status());
}
TEST(CCOpTest, ControlDeps) {
Scope root = Scope::NewRootScope();
auto v = Variable(root, {}, DT_FLOAT);
auto assign = Assign(root, v, 41.0f);
Scope with_control_deps = root.WithControlDependencies(assign);
auto add = Add(with_control_deps, v, 1.0f);
Scope no_control_deps = with_control_deps.WithNoControlDependencies();
auto sub = Sub(no_control_deps, 3.0f, 2.0f);
auto is_inited =
IsVariableInitialized(no_control_deps.WithControlDependencies(sub), v);
TF_EXPECT_OK(root.status());
std::vector<Tensor> out;
GetTensors(root, {add}, &out);
test::ExpectTensorNear<float>(out[0], test::AsTensor<float>({42.0f}, {}),
1e-5);
out.clear();
// Note: GetTensors creates a new session, so 'v' is uninitialized.
// sub should have no control deps, so it should not cause the assign to run.
// Hence is_inited should be false.
GetTensors(root, {sub, is_inited}, &out);
test::ExpectTensorNear<float>(out[0], test::AsTensor<float>({1.0f}, {}),
1e-5);
test::ExpectTensorEqual<bool>(out[1], test::AsTensor<bool>({false}, {}));
}
TEST(CCOpTest, KernelLabel) {
Scope root = Scope::NewRootScope();
auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f);
TF_EXPECT_OK(root.status());
const auto& attrs = add.z.op().node()->def().attr();
ASSERT_TRUE(attrs.find("_kernel") != attrs.end());
auto kernel_attr = attrs.find("_kernel")->second;
TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string"));
EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel");
}
TEST(CCOpTest, ColocateWith) {
Scope root = Scope::NewRootScope();
auto c1 = Const(root.WithOpName("c1"), 1);
auto c2 = Const(root.WithOpName("c2").ColocateWith(c1), 2);
std::vector<string> constraints;
GetColocationConstraints(c2, &constraints);
EXPECT_EQ(constraints[0], "loc:@c1");
auto c3 = Const(root.WithOpName("c3").ColocateWith(c2), 3);
GetColocationConstraints(c3, &constraints);
EXPECT_EQ(constraints[0], "loc:@c1");
auto a = Const(root.WithOpName("a"), 4);
auto c4 = Const(root.WithOpName("c4").ColocateWith(a), 5);
GetColocationConstraints(c4, &constraints);
EXPECT_EQ(constraints[0], "loc:@a");
auto c5 = Const(root.WithOpName("c5").ColocateWith(c3).ColocateWith(c4), 6);
GetColocationConstraints(c5, &constraints);
EXPECT_EQ(constraints[0], "loc:@a");
EXPECT_EQ(constraints[1], "loc:@c1");
Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4);
auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7);
const auto& attrs = c6.op().node()->def().attr();
EXPECT_TRUE(attrs.find("_class") == attrs.end());
}
} // namespace tensorflow

tensorflow/cc/framework/ops.cc

@@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/ops.h"
namespace tensorflow {
namespace ops {
Input::Initializer::Initializer(
const std::initializer_list<Input::Initializer>& v) {
if (v.size() < 1) {
// Empty initializer list defaults to float tensor with shape (0,)
tensor = Tensor(DT_FLOAT, TensorShape{0});
return;
}
auto const& first = *v.begin();
// Check to make sure that the constituent Initializers are all the same
// type and same shape.
for (auto const& e : v) {
if (e.tensor.dtype() != first.tensor.dtype()) {
status = errors::InvalidArgument(
"Initializer list components should all have the same type");
return;
}
if (!TensorShape{e.tensor.shape()}.IsSameSize(
TensorShape{first.tensor.shape()})) {
status = errors::InvalidArgument(
"Initializer list components should all have the same shape");
return;
}
}
// Form the new shape.
TensorShape shape{static_cast<int64>(v.size())};
shape.AppendShape(TensorShape{first.tensor.shape()});
Tensor t(first.tensor.dtype(), shape);
// Collate the constituent Tensors.
size_t offset = 0;
for (auto const& e : v) {
Tensor elem = e.tensor;
if (first.tensor.dtype() == DT_STRING) {
for (int i = 0; i < elem.NumElements(); ++i) {
t.flat<string>()(offset + i) = elem.flat<string>()(i);
}
offset += elem.NumElements();
} else {
std::copy_n(elem.tensor_data().data(), elem.TotalBytes(),
const_cast<char*>(t.tensor_data().data()) + offset);
offset += elem.TotalBytes();
}
}
tensor = t;
}
} // namespace ops
} // namespace tensorflow

tensorflow/cc/framework/ops.h

@@ -0,0 +1,261 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
#include <type_traits>
#include "tensorflow/core/framework/tensor.h"
// TBD(keveman): This is going to be moved to //third_party/tensorflow
// eventually. Remove the NOLINT comment when moving.
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace ops {
// Represents a node in the computation graph.
class Operation {
public:
Operation() : node_(nullptr) {}
explicit Operation(Node* n) : node_(n) {}
int num_outputs() const { return node_->num_outputs(); }
DataType output_type(int o) const { return node_->output_type(o); }
Node* node() const { return node_; }
private:
Node* node_;
};
// Represents a tensor value produced by an Operation.
class Output {
public:
Output() = default;
explicit Output(Node* n) : op_(n) {}
Output(Node* n, int64 index) : op_(n), index_(index) {}
Output(const Operation& op, int64 index) : op_(op), index_(index) {}
Operation op() const { return op_; }
Node* node() const { return op().node(); }
int64 index() const { return index_; }
DataType type() const { return op_.output_type(index_); }
private:
Operation op_ = Operation(nullptr);
int64 index_ = 0;
};
// Represents a tensor value that can be used as an operand to an Operation.
class Input {
public:
// Initializer enables constructing an Input object from various kinds of C++
// constants such as simple primitive constants and nested initializer lists
// representing a multi-dimensional array. Initializer constructors are all
// templates, so the aforementioned kinds of C++ constants can be used to
construct an Initializer. An Initializer stores the value it was constructed
with in a Tensor object.
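// A hypothetical usage sketch (values for illustration only):
//   Input::Initializer scalar(42);               // 0-D int32 tensor
//   Input::Initializer vec({1.0f, 2.0f, 3.0f});  // 1-D float tensor
//   Input::Initializer mat({{1, 2}, {3, 4}});    // 2-D tensor via nesting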
struct Initializer {
// Construct from a scalar value of an arithmetic type or a type that can be
converted to a string (e.g., a string literal).
template <typename T, typename = typename std::enable_if<
std::is_arithmetic<T>::value ||
std::is_convertible<T, string>::value>::type>
Initializer(const T& v) { // NOLINT(runtime/explicit)
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
t.flat<RealT>()(0) = RealT(v);
tensor = t;
}
explicit Initializer(const Tensor& t) : tensor(t) {}
// Construct from a scalar value and an explicit shape
template <typename T, typename = typename std::enable_if<
std::is_arithmetic<T>::value ||
std::is_convertible<T, string>::value>::type>
Initializer(const T& v, const TensorShape& shape) {
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), shape);
for (int64 i = 0; i < t.NumElements(); ++i) {
t.flat<RealT>()(i) = RealT(v);
}
tensor = t;
}
// Construct from an initializer list of scalars (a one-dimensional tensor).
template <typename T, typename = typename std::enable_if<
std::is_arithmetic<T>::value ||
std::is_convertible<T, string>::value>::type>
Initializer(
const std::initializer_list<T>& v) { // NOLINT(runtime/explicit)
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(),
TensorShape{static_cast<int>(v.size())});
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
tensor = t;
}
// Construct from an initializer list of scalars and an explicit shape.
template <typename T, typename = typename std::enable_if<
std::is_arithmetic<T>::value ||
std::is_convertible<T, string>::value>::type>
Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
typedef typename RealType<T>::type RealT;
Tensor t(DataTypeToEnum<RealT>::v(), shape);
if (t.NumElements() != v.size()) {
status = errors::InvalidArgument(
"Cannot construct a tensor with ", t.NumElements(),
" from an initializer list with ", v.size(), " elements");
return;
}
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
tensor = t;
}
// Construct a multi-dimensional tensor from a nested initializer list. Note
// that C++ syntax allows nesting of arbitrarily typed initializer lists, so
// such invalid initializers cannot be disallowed at compile time. This
// function performs checks to make sure that the nested initializer list is
// indeed a valid multi-dimensional tensor.
Initializer(const std::initializer_list<Initializer>& v);
template <typename T, bool = std::is_convertible<T, string>::value>
struct RealType {
typedef string type;
};
template <typename T>
struct RealType<T, false> {
typedef T type;
};
TensorProto AsTensorProto() {
TensorProto tensor_proto;
if (tensor.NumElements() > 1) {
tensor.AsProtoTensorContent(&tensor_proto);
} else {
tensor.AsProtoField(&tensor_proto);
}
return tensor_proto;
}
Status status;
Tensor tensor;
};
// All of Input's constructors are implicit. Input can be implicitly
// constructed from the following objects :
// * Output: This is so that the output of an Operation can be directly used
// as the input to an op wrapper, which takes Inputs.
// * A scalar, or a multi-dimensional tensor specified as a recursive
// initializer list. This enables directly passing constants as
// inputs to op wrappers.
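// For example (a sketch; MatMul is a generated wrapper, values illustrative):
//   auto m = MatMul(scope, prev_output, {{41}, {1}});
// Here an existing Output and a nested initializer list both convert
// implicitly to Input.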
Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit)
template <typename T, typename = typename std::enable_if<
std::is_arithmetic<T>::value ||
std::is_convertible<T, string>::value>::type>
Input(const T& v) // NOLINT(runtime/explicit)
: Input(Initializer(v)) {}
Input(const Initializer& init) // NOLINT(runtime/explicit)
: status_(init.status),
tensor_(init.tensor) {}
Input(const Tensor& t) // NOLINT(runtime/explicit)
: status_(Status::OK()),
tensor_(t) {}
Input(const std::initializer_list<Initializer>&
init) { // NOLINT(runtime/explicit)
for (const auto& i : init) {
if (!i.status.ok()) {
status_ = i.status;
return;
}
}
tensor_ = Initializer(init).tensor;
}
// Constructor specifying a node name, index and datatype. This should only be
// used for specifying a backward edge, needed by control flow.
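// A hypothetical sketch (node name and type assumed for illustration):
//   Input back_edge("while/Merge", 0, DT_FLOAT);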
Input(const string& name, int i, DataType dt)
: node_name_(name), index_(i), data_type_(dt) {}
Node* node() const { return output_.node(); }
string node_name() const { return node_name_; }
int index() const { return node_name_.empty() ? output_.index() : index_; }
DataType data_type() const { return data_type_; }
Status status() const { return status_; }
const Tensor& tensor() const { return tensor_; }
private:
Status status_;
Output output_ = Output(Operation(nullptr), 0);
Tensor tensor_;
const string node_name_ = "";
int index_ = 0;
DataType data_type_ = DT_INVALID;
};
// A type for representing the output of ops that produce more than one output,
// or a list of tensors.
typedef std::vector<Output> OutputList;
// A type for representing the input to ops that require a list of tensors.
class InputList {
public:
// Implicitly convert a list of outputs to a list of inputs. This is useful to
// write code such as tf.Concat(tf.Split(x, 4)).
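// A C++ sketch of that pattern (adapted from the SplitConcat test in this
// change; a Split's whole OutputList also converts via the constructor
// below):
//   Split s(scope, 0, x, 4);
//   auto c = Concat(scope, 0, {s[0], s[1], s[2], s[3]});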
InputList(const OutputList& out) { // NOLINT(runtime/explicit)
for (auto const& x : out) {
inputs_.push_back(x);
}
}
InputList(
const std::initializer_list<Input>& inputs) // NOLINT(runtime/explicit)
: inputs_(inputs.begin(), inputs.end()) {}
InputList(const tensorflow::gtl::ArraySlice<Input>&
inputs) // NOLINT(runtime/explicit)
: inputs_(inputs.begin(), inputs.end()) {}
InputList(
const std::initializer_list<Output>& out) { // NOLINT(runtime/explicit)
for (auto const& x : out) {
inputs_.push_back(x);
}
}
typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
typename std::vector<Input>::iterator end() { return inputs_.end(); }
typename std::vector<Input>::const_iterator begin() const {
return inputs_.begin();
}
typename std::vector<Input>::const_iterator end() const {
return inputs_.end();
}
private:
std::vector<Input> inputs_;
};
} // namespace ops
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_

tensorflow/cc/framework/scope.cc

@@ -0,0 +1,347 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <vector>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map)
: graph_(graph),
status_(status),
name_map_(name_map),
scope_used_(nullptr) {}
Scope Scope::NewRootScope() {
return Scope(new Graph(OpRegistry::Global()), new Status, new Scope::NameMap);
}
Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name,
bool copy_names)
: graph_(other.graph_),
status_(other.status_),
name_map_(copy_names ? other.name_map_
: std::shared_ptr<NameMap>(new NameMap)),
scope_used_(nullptr),
control_deps_(other.control_deps_),
name_(name),
op_name_(""),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::OpName, const string& name,
const string& op_name)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(name),
op_name_(op_name),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::ControlDeps,
std::vector<ops::Operation> control_deps, bool clear_control_deps)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(clear_control_deps
? std::vector<ops::Operation>()
: (control_deps.insert(control_deps.begin(),
other.control_deps_.begin(),
other.control_deps_.end()),
control_deps)),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::Device, const string& device)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(device),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::SingleUseScope,
const string& op_name)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(new bool(false)),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(op_name),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::ExitOnError)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(true),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::KernelLabel,
const string& kernel_label)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(kernel_label),
device_(other.device_),
colocation_constraints_(other.colocation_constraints_) {}
Scope::Scope(const Scope& other, Scope::Tags::Colocate,
const ops::Operation& colocate_with_op, bool clear_colocations)
: graph_(other.graph_),
status_(other.status_),
name_map_(other.name_map_),
scope_used_(other.scope_used_),
control_deps_(other.control_deps_),
name_(other.name_),
op_name_(other.op_name_),
exit_on_error_(other.exit_on_error_),
kernel_label_(other.kernel_label_),
device_(other.device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.GetColocationConstraints(colocate_with_op)) {}
std::unordered_set<string> Scope::GetColocationConstraints(
const ops::Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
const NodeDef& node_def = colocate_with_op.node()->def();
if (node_def.attr().find("_class") != node_def.attr().end()) {
const AttrValue& loc = node_def.attr().find("_class")->second;
if (loc.value_case() == AttrValue::kList && loc.list().s_size() > 0) {
for (int i = 0; i < loc.list().s_size(); ++i) {
// Filter out the ones that don't have "loc:@" prefix
if (loc.list().s(i).find("loc:@") == 0) {
// Skip the "loc:@" prefix
current_constraints.insert(loc.list().s(i).substr(5));
}
}
}
} else {
current_constraints.insert(colocate_with_op.node()->name());
}
return current_constraints;
}
void Scope::UpdateStatus(const Status s) const {
status_->Update(s);
if (exit_on_error_ && !status_->ok()) {
LOG(FATAL) << *status_;
}
}
Status Scope::ToGraphDef(GraphDef* gdef) const {
if (!status_->ok()) {
return *status_;
}
graph()->ToGraphDef(gdef);
return Status::OK();
}
Status Scope::ToGraph(Graph* g) const {
if (status_->ok()) {
GraphDef graph_def;
graph()->ToGraphDef(&graph_def);
GraphConstructorOptions opts;
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
}
return *status_;
}
void Scope::UpdateBuilder(NodeBuilder* builder) const {
std::vector<Node*> control_inputs;
for (const auto& op : control_deps_) {
control_inputs.push_back(op.node());
}
builder->ControlInputs(control_inputs);
if (!kernel_label_.empty()) {
builder->Attr("_kernel", kernel_label_);
}
if (!colocation_constraints_.empty()) {
std::vector<string> constraints(colocation_constraints_.begin(),
colocation_constraints_.end());
// Sort the set.
std::sort(constraints.begin(), constraints.end());
// Add loc:@ prefix
std::transform(constraints.begin(), constraints.end(), constraints.begin(),
[](const string& s) { return strings::StrCat("loc:@", s); });
builder->Attr("_class", constraints);
}
if (!device_.empty()) {
builder->Device(device_);
}
}
string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
if (check_single_use && single_use_scope()) {
if (*scope_used_) {
*status_ =
errors::AlreadyExists(prefix, " already exists in the current scope");
return "";
}
*scope_used_ = true;
return prefix;
}
auto entry = name_map_->find(prefix);
string unique_name = prefix;
if (entry == name_map_->end()) {
name_map_->insert({prefix, 0});
} else {
unique_name = strings::StrCat(unique_name, "_", ++entry->second);
}
return unique_name;
}
string Scope::GetNameForOp(const string& default_name) const {
const string unique_name =
GetUniqueName(default_name, true /* check_single_use */);
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
return strings::StrCat(name_, sep, unique_name);
}
string Scope::GetUniqueNameForOp(const string& default_name) const {
if (single_use_scope()) {
if (op_name_.empty() || *scope_used_) {
*status_ =
errors::InvalidArgument("Cannot get a unique name in this scope");
return "";
}
*scope_used_ = true;
return op_name_;
}
return op_name_.empty() ? GetNameForOp(default_name) : GetNameForOp(op_name_);
}
Scope Scope::NewSubScope(const string& child_scope_name) const {
if (child_scope_name.empty()) {
return Scope(*this, Scope::Tags::ScopeName(), name_, true /* copy_names */);
}
const string unique_name =
GetUniqueName(child_scope_name, false /* check_single_use */);
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
return Scope(*this, Scope::Tags::ScopeName(),
strings::StrCat(name_, sep, unique_name),
false /* copy_names */);
}
Scope Scope::WithOpName(const string& op_name) const {
if (single_use_scope()) {
UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
" on this scope"));
return *this;
}
return Scope(*this, Scope::Tags::OpName(), name_, op_name);
}
Scope Scope::WithControlDependencies(
const gtl::ArraySlice<ops::Operation>& control_deps) const {
return Scope(
*this, Scope::Tags::ControlDeps(),
std::vector<ops::Operation>(control_deps.begin(), control_deps.end()),
/* clear_control_deps */ false);
}
Scope Scope::WithControlDependencies(const ops::Output& control_dep) const {
return Scope(*this, Scope::Tags::ControlDeps(),
std::vector<ops::Operation>(1, control_dep.op()),
/* clear_control_deps */ false);
}
Scope Scope::WithNoControlDependencies() const {
return Scope(*this, Scope::Tags::ControlDeps(), std::vector<ops::Operation>(),
/* clear_control_deps */ true);
}
Scope Scope::WithDevice(const string& device) const {
return Scope(*this, Scope::Tags::Device(), device);
}
Scope Scope::ColocateWith(const ops::Operation& op) const {
return Scope(*this, Scope::Tags::Colocate(), op,
/* clear_colocations */ false);
}
Scope Scope::ClearColocation() const {
return Scope(*this, Scope::Tags::Colocate(), ops::Operation(),
/* clear_colocations */ true);
}
Scope Scope::ExitOnError() const {
return Scope(*this, Scope::Tags::ExitOnError());
}
Scope Scope::WithKernelLabel(const string& kernel_label) const {
return Scope(*this, Scope::Tags::KernelLabel(), kernel_label);
}
CompositeOpScopes Scope::GetCompositeOpScopes(
const string& composite_op_name) const {
if (op_name_.empty() && composite_op_name.empty()) {
UpdateStatus(errors::InvalidArgument(
"Cannot create composite op scopes with empty name"));
return {*this, *this};
}
if (!single_use_scope()) {
Scope child = NewSubScope(op_name_.empty() ? composite_op_name : op_name_);
const string child_op_sep = name_.empty() ? "" : "_";
return {child, Scope(child, Scope::Tags::SingleUseScope(),
strings::StrCat(name_, child_op_sep, child.name_))};
} else {
return {
Scope(*this, Scope::Tags::ScopeName(), op_name_, true /* copy_names */),
*this};
}
}
} // namespace tensorflow
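For illustration, a minimal sketch of how these scope properties combine when
constructing ops. It assumes the generated wrappers (ops::Const, ops::MatMul)
with the (Scope, ...) calling convention introduced elsewhere in this change:
// Sketch: the device and name prefix set on a scope are inherited by every
// op constructed with that scope (or a scope derived from it).
Scope root = Scope::NewRootScope();
Scope layer = root.WithDevice("/gpu:0").NewSubScope("layer");
auto w = ops::Const(layer.WithOpName("w"), {{1.f, 2.f}, {3.f, 4.f}});  // "layer/w"
auto x = ops::Const(layer.WithOpName("x"), {{1.f}, {1.f}});            // "layer/x"
auto y = ops::MatMul(layer.WithOpName("y"), w, x);                     // "layer/y"
GraphDef gdef;
Status s = root.ToGraphDef(&gdef);  // the Status object is shared with all children
if (!s.ok()) { /* handle error */ }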

View File

@ -0,0 +1,260 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class GraphDef;
class NodeBuilder;
struct CompositeOpScopes;
// A `Scope` object represents a set of related TensorFlow ops that have the
// same properties such as a common name prefix.
// A Scope object is a container for TensorFlow Op properties. Op constructors
// get a Scope object as a mandatory first argument and the constructed op
// acquires the properties in the object.
//
// A simple example:
//
// using namespace ops;
// Scope root = Scope::NewRootScope();
// auto c1 = Const(root, {{1, 1}});
// auto m = MatMul(root, c1, {{41}, {1}});
// GraphDef gdef;
// Status s = root.ToGraphDef(&gdef);
// if (!s.ok()) { /* Handle error */ }
//
// Scope hierarchy:
// The Scope class provides various With<> functions that create a new scope.
// The new scope typically has one property changed while other properties are
// inherited from the parent scope.
// NewSubScope(name) method appends `name` to the prefix of names for ops
// created within the scope, and WithOpName() changes the suffix which
// otherwise defaults to the type of the op.
//
// Name examples:
// Scope root = Scope::NewRootScope();
// Scope linear = root.NewSubScope("linear");
// /* W will be named "linear/W" */
// auto W = Variable(linear.WithOpName("W"),
// {2, 2}, DT_FLOAT);
// /* b will be named "linear/b" */
// auto b = Variable(linear.WithOpName("b"),
// {2}, DT_FLOAT);
// auto x = Const(linear, {...}); // name: "linear/Const"
// auto m = MatMul(linear, x, W); // name: "linear/MatMul"
// auto r = BiasAdd(linear, m, b); // name: "linear/BiasAdd"
//
// Scope lifetime:
// A new scope is created by calling Scope::NewRootScope. This creates some
// resources that are shared by all the child scopes that inherit from this
// scope, directly or transitively. For instance, a new scope creates a new
// Graph object to which operations are added when the new scope or its children
// are used by an Op constructor. The new scope also has a Status object which
// will be used to indicate errors by Op-constructor functions called on any
// child scope. The Op-constructor functions have to check the scope's status by
// calling the ok() method before proceeding to construct the op.
class Scope {
public:
// The following functions are for users making graphs. They return brand new
// scopes, or scopes derived from an existing scope object.
// Return a new scope.
// This creates a new graph and all operations constructed in this graph
// should use the returned object as the "root" scope.
static Scope NewRootScope();
// Return a new scope. Ops created with this scope will have
// <name>/<child_scope_name> as the prefix. The actual name will be unique
// in the current scope. All other properties are inherited from the current
// scope. If child_scope_name is empty, the '/' is elided.
Scope NewSubScope(const string& child_scope_name) const;
// Return a new scope. All ops created within the returned scope will have
// names of the form <name>/<op_name>[_<suffix>].
Scope WithOpName(const string& op_name) const;
// Return a new scope. All ops created within the returned scope will have as
// control dependencies the union of operations in the control_deps vector and
// the control dependencies of the current scope.
Scope WithControlDependencies(
const gtl::ArraySlice<ops::Operation>& control_deps) const;
// Same as above, but convenient for adding a control dependency on the
// operation producing the control_dep output.
Scope WithControlDependencies(const ops::Output& control_dep) const;
// Return a new scope. All ops created within the returned scope will have no
// control dependencies on other operations.
Scope WithNoControlDependencies() const;
// Return a new scope. All ops created within the returned scope will have the
// device field set to 'device'.
Scope WithDevice(const string& device) const;
// Return a new scope. All ops created within the returned scope will be
// co-located on the device where op is placed.
// NOTE: This function is intended to be used by internal libraries only for
// controlling placement of ops on devices. Public use is not encouraged
// because the implementation of device placement is subject to change.
Scope ColocateWith(const ops::Operation& op) const;
// Convenience function for above.
Scope ColocateWith(const ops::Output& out) const {
return ColocateWith(out.op());
}
// Clear all colocation constraints.
Scope ClearColocation() const;
// Return a new scope. The op-constructor functions taking the returned scope
// as the scope argument will exit as soon as an error is detected, instead of
// setting the status on the scope.
Scope ExitOnError() const;
// Return a new scope. All ops created with the new scope will have
// kernel_label as the value for their '_kernel' attribute.
Scope WithKernelLabel(const string& kernel_label) const;
// The following functions are for scope object consumers.
// Return a unique name, using default_name if an op name has not been
// specified.
string GetUniqueNameForOp(const string& default_name) const;
// Update the status on this scope.
// Note: The status object is shared between all children of this scope.
// If the resulting status is not Status::OK() and exit_on_error_ is set on
// this scope, this function exits by calling LOG(FATAL).
void UpdateStatus(const Status s) const;
// Update the builder with properties accumulated in this scope.
void UpdateBuilder(NodeBuilder* builder) const;
CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const;
bool ok() const { return status_->ok(); }
Graph* graph() const { return graph_.get(); }
Status status() const { return *status_; }
// If status() is Status::OK(), convert the Graph object stored in this scope
// to a GraphDef proto and return Status::OK(). Otherwise, return the error
// status as is without performing GraphDef conversion.
Status ToGraphDef(GraphDef* gdef) const;
// If status() is Status::OK(), construct a Graph object using the default
// GraphConstructorOptions, and return Status::OK() if graph construction was
// successful. Otherwise, return the error status.
// TODO(josh11b, keveman): Make this faster; right now it converts
// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const;
const std::vector<ops::Operation>& control_deps() const {
return control_deps_;
}
private:
// Tag types to choose the constructor to dispatch.
struct Tags {
enum class ScopeName;
enum class OpName;
enum class ControlDeps;
enum class Device;
enum class SingleUseScope;
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
};
// A NameMap is used to keep track of suffixes for names used in a scope. A
// name that has not been used so far in a scope will get no suffix. Later
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
// can share the same NameMap. For instance, a new scope created using
// WithControlDependencies() would share the same NameMap with its
// parent.
typedef std::unordered_map<string, int> NameMap;
Scope(Graph* graph, Status* status, NameMap* name_map);
Scope(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Scope(const Scope& other, Tags::OpName, const string& name,
const string& op_name);
Scope(const Scope& other, Tags::ControlDeps,
std::vector<ops::Operation> control_deps, bool clear_control_deps);
Scope(const Scope& other, Tags::Device, const string& device);
Scope(const Scope& other, Tags::SingleUseScope, const string& op_name);
Scope(const Scope& other, Tags::ExitOnError);
Scope(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Scope(const Scope& other, Tags::Colocate,
const ops::Operation& colocate_with_op, bool clear_colocations);
std::unordered_set<string> GetColocationConstraints(
const ops::Operation& colocate_with_op) const;
// Helper functions to get unique names.
string GetUniqueName(const string& prefix, bool check_single_use) const;
string GetNameForOp(const string& default_name) const;
bool single_use_scope() const { return scope_used_ != nullptr; }
// The graph, status, and name maps are shared by all child scopes
// created from a single 'root' scope. A root scope is created by calling the
// Scope::NewRootScope function, which creates a new graph, a new status and
// the name maps.
std::shared_ptr<Graph> graph_ = nullptr;
std::shared_ptr<Status> status_ = nullptr;
std::shared_ptr<NameMap> name_map_ = nullptr;
// If scope_used_ is not nullptr, op_name_ should be empty and
// GetUniqueNameForOp can only be called once on this scope. More calls to
// GetUniqueNameForOp will cause an error status to be set on this scope.
std::shared_ptr<bool> scope_used_ = nullptr;
const std::vector<ops::Operation> control_deps_;
const string name_ = "";
const string op_name_ = "";
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
};
// A helper struct to hold the scopes that would be used by a function
// constructing a composite op.
struct CompositeOpScopes {
// Scope to be used for creating the local ops (primitive or other composite
// ops).
Scope child;
// Scope to be used for creating the last op.
Scope last;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
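To show how GetCompositeOpScopes and CompositeOpScopes fit together, here is a
sketch of a composite-op function; the ops::MatMul and ops::BiasAdd wrapper
signatures are assumptions here:
// Internal ops are created with `child` (named e.g. "linear/MatMul"), while
// the final op is created with `last` and takes the composite op's own name
// ("linear", or whatever WithOpName() set on the incoming scope).
ops::Output Linear(const Scope& scope, ops::Input x, ops::Input w,
                   ops::Input b) {
  auto cop_scopes = scope.GetCompositeOpScopes("linear");
  auto m = ops::MatMul(cop_scopes.child, x, w);
  return ops::BiasAdd(cop_scopes.last, m, b);
}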

View File

@ -0,0 +1,138 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
TEST(ScopeTest, BasicNames) {
Scope root = Scope::NewRootScope();
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add");
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add_1");
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add_2");
EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul");
}
TEST(ScopeTest, HierarchicalNames) {
Scope root = Scope::NewRootScope();
Scope child = root.NewSubScope("child");
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add");
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add_1");
EXPECT_EQ(child.GetUniqueNameForOp("mul"), "child/mul");
Scope child_1 = root.NewSubScope("child");
EXPECT_EQ(child_1.GetUniqueNameForOp("add"), "child_1/add");
EXPECT_EQ(child_1.GetUniqueNameForOp("add"), "child_1/add_1");
EXPECT_EQ(child_1.GetUniqueNameForOp("mul"), "child_1/mul");
Scope c_c = root.NewSubScope("c").NewSubScope("c");
EXPECT_EQ(c_c.GetUniqueNameForOp("add"), "c/c/add");
Scope c_1 = root.NewSubScope("c");
Scope c_1_c = c_1.NewSubScope("c");
EXPECT_EQ(c_1_c.GetUniqueNameForOp("add"), "c_1/c/add");
Scope c_1_c_1 = c_1.NewSubScope("c");
EXPECT_EQ(c_1_c_1.GetUniqueNameForOp("add"), "c_1/c_1/add");
EXPECT_EQ(root.NewSubScope("").NewSubScope("").GetUniqueNameForOp("d"), "d");
EXPECT_EQ(root.NewSubScope("").GetUniqueNameForOp("d"), "d_1");
EXPECT_EQ(root.GetUniqueNameForOp("d"), "d_2");
}
TEST(ScopeTest, ScopeAndOpNames) {
Scope root = Scope::NewRootScope();
Scope child = root.NewSubScope("child");
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add");
EXPECT_EQ(root.GetUniqueNameForOp("child"), "child_1");
EXPECT_EQ(root.NewSubScope("child").GetUniqueNameForOp("p"), "child_2/p");
}
namespace {
string LastOp(const Scope& scope) { return scope.GetUniqueNameForOp("Last"); }
std::vector<string> AnotherCompositeOp(const Scope& scope) {
auto cop_scopes = scope.GetCompositeOpScopes("another_cop");
const string c1 = cop_scopes.child.GetUniqueNameForOp("c1");
const string c2 = cop_scopes.child.GetUniqueNameForOp("mul");
return {c1, c2, LastOp(cop_scopes.last)};
}
std::vector<string> LinearOp(const Scope& scope) {
auto cop_scopes = scope.GetCompositeOpScopes("linear");
Scope linear = cop_scopes.child;
const string mul_op_name = linear.GetUniqueNameForOp("mul");
const string bias_add_op_name = linear.GetUniqueNameForOp("bias_add");
auto cop_names = AnotherCompositeOp(cop_scopes.last);
return {mul_op_name, bias_add_op_name, cop_names[0], cop_names[1],
cop_names[2]};
}
} // namespace
TEST(ScopeTest, CompositeOp) {
Scope root = Scope::NewRootScope();
const auto names1 = LinearOp(root);
EXPECT_EQ(names1[0], "linear/mul");
EXPECT_EQ(names1[1], "linear/bias_add");
EXPECT_EQ(names1[2], "linear/c1");
EXPECT_EQ(names1[3], "linear/mul_1");
EXPECT_EQ(names1[4], "linear");
EXPECT_EQ(root.GetUniqueNameForOp("linear"), "linear_1");
const auto names2 = LinearOp(root);
EXPECT_EQ(names2[0], "linear_2/mul");
EXPECT_EQ(names2[1], "linear_2/bias_add");
EXPECT_EQ(names2[2], "linear_2/c1");
EXPECT_EQ(names2[3], "linear_2/mul_1");
EXPECT_EQ(names2[4], "linear_2");
const auto names3 = LinearOp(root.WithOpName("c"));
EXPECT_EQ(names3[0], "c/mul");
EXPECT_EQ(names3[1], "c/bias_add");
EXPECT_EQ(names3[2], "c/c1");
EXPECT_EQ(names3[3], "c/mul_1");
EXPECT_EQ(names3[4], "c");
}
TEST(ScopeTest, SingleUseScope) {
Scope root = Scope::NewRootScope();
auto cop_scopes = root.GetCompositeOpScopes("cop");
// cop_scopes.last is a single use scope
EXPECT_EQ(cop_scopes.last.GetUniqueNameForOp("foo"), "cop");
cop_scopes.last.GetUniqueNameForOp("foo");
// Error status should be set on cop_scopes.last
EXPECT_FALSE(cop_scopes.last.ok());
}
TEST(ScopeTest, ControlDeps) {
Scope root = Scope::NewRootScope();
auto c1 = ops::Operation();
auto c2 = ops::Operation();
Scope c = root.WithControlDependencies({c1, c2});
EXPECT_EQ(c.control_deps().size(), 2);
Scope c_c = c.WithControlDependencies({ops::Operation()});
EXPECT_EQ(c_c.control_deps().size(), 3);
}
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("ThrowAway1")
.Input("ret: int32")
.Input("unique_name: float")
.Input("for: int32")
.Attr("scope: int")
.Attr("builder: int = 1")
.Attr("while: int")
.Doc(R"doc(
Op to test keywords and reserved words in input and attr names.
ret: Return value.
for: Keyword as name for input.
while: Keyword as name for attr.
)doc");
REGISTER_OP("ThrowAway2")
.Attr("scope: int = 2")
.Attr("throw_away2: int = 2")
.Attr("attrs: int = 4")
.Attr("node: int = 4");
REGISTER_OP("ThrowAway3").Output("node: int32");
REGISTER_OP("ThrowAway4").Input("node: int32");
REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
} // namespace tensorflow

View File

@ -1,382 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords
// or "opts".
#include "tensorflow/cc/ops/cc_op_gen.h"
#include <unordered_map>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
const int kRightMargin = 79;
const char* AttrTypeName(StringPiece attr_type) {
static const char* kAttrTypeName[][2] = {
{"string", "StringPiece"},
{"list(string)", "gtl::ArraySlice<string>"},
{"int", "int64"},
{"list(int)", "gtl::ArraySlice<int>"},
{"float", "float"},
{"list(float)", "gtl::ArraySlice<float>"},
{"bool", "bool"},
{"list(bool)", "gtl::ArraySlice<bool>"},
{"type", "DataType"},
{"list(type)", "DataTypeSlice"},
{"shape", "TensorShape"},
{"list(shape)", "gtl::ArraySlice<TensorShape>"},
{"tensor", "const Tensor&"},
{"list(tensor)", "gtl::ArraySlice<Tensor>"},
{"func", "const NameAttrList&"},
};
for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
if (attr_type == kAttrTypeName[i][0]) {
return kAttrTypeName[i][1];
}
}
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return "";
}
// Change:       Into:
//   ABC         // ABC
//               //
//   DEF         // DEF
string MakeComment(StringPiece text) {
string ret;
while (!text.empty()) {
int last_non_space = -1;
int newline;
for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
if (text[newline] == '\n') break;
if (text[newline] != ' ') last_non_space = newline;
}
if (last_non_space == -1) {
strings::StrAppend(&ret, "//\n");
} else {
strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n");
}
text.remove_prefix(newline + 1);
}
return ret;
}
void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
// TODO(josh11b): Better wrapping of comments.
string comment;
if (op_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
if (op_def.has_deprecation()) {
strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ",
op_def.deprecation().version(), ":\n",
op_def.deprecation().explanation(), ".\n");
}
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
}
}
static const string kSingleInputType = "NodeOut";
static const string kListInputType = "gtl::ArraySlice<NodeOut>";
std::vector<string> arg_types;
std::vector<string> arg_names;
strings::StrAppend(&comment, "\nArguments:\n");
// Map from attr name to the first input arg it is inferred from.
std::unordered_map<string, string> inferred_attrs;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
arg_names.emplace_back(arg.name());
bool is_list = false;
if (!arg.type_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
} else if (!arg.type_list_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
arg.name());
is_list = true;
}
if (!arg.number_attr().empty()) {
gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
is_list = true;
}
if (is_list) {
arg_types.emplace_back(kListInputType);
} else {
arg_types.emplace_back(kSingleInputType);
}
// TODO(josh11b): Include input type information.
StringPiece description = arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
arg.description(), "\n");
}
}
string options_comment;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
// Do not add inferred attrs or attrs with defaults to the C++
// function signature.
if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
if (!attr.has_default_value()) {
arg_names.emplace_back(attr.name());
arg_types.emplace_back(AttrTypeName(attr.type()));
if (!attr.description().empty()) {
strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
attr.description(), "\n");
}
} else {
strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(),
"\", ", AttrTypeName(attr.type()), "): Defaults to ",
SummarizeAttrValue(attr.default_value()), ".\n");
if (!attr.description().empty()) {
strings::StrAppend(&options_comment, " ", attr.description(),
"\n");
}
}
}
}
CHECK_EQ(arg_names.size(), arg_types.size());
strings::StrAppend(&comment, "* opts:\n", options_comment,
R"comment( .WithName(StringPiece): Set the Node's name
.WithDevice(StringPiece): Set the Node's requested device
.WithControlInput(Node*) / .WithControlInputs({Node*, ...}):
Add control dependencies on the specified Node(s).
Returns a pointer to the created Node)comment");
// TODO(josh11b): Include output type information.
if (op_def.output_arg_size() == 0) {
strings::StrAppend(&comment, ".\n");
} else if (op_def.output_arg_size() == 1) {
StringPiece description = op_def.output_arg(0).description();
ConsumeEquals(&description);
if (description.empty()) {
strings::StrAppend(&comment, ".\n");
} else {
strings::StrAppend(&comment, ", with output:\n", description, "\n");
}
} else {
strings::StrAppend(&comment, ", with outputs:\n");
for (int o = 0; o < op_def.output_arg_size(); ++o) {
StringPiece description = op_def.output_arg(o).description();
ConsumeEquals(&description);
if (description.empty()) {
strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n");
} else {
strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ",
description, "\n");
}
}
}
// Write the header comment.
TF_CHECK_OK(h->Append(MakeComment(comment)));
// Declare the function wrapper.
const string prefix = strings::StrCat("Node* ", op_def.name(), "(");
string h_rest;
for (size_t i = 0; i < arg_names.size(); ++i) {
strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", ");
}
strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts");
string cc_decl = h_rest;
strings::StrAppend(&h_rest, ");");
TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n"));
// Define the function wrapper.
strings::StrAppend(&cc_decl, ") {");
TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n"));
const string op_name = strings::StrCat(" static const string kOpName = \"",
op_def.name(), "\";\n");
if (arg_types.empty()) {
TF_CHECK_OK(cc->Append(op_name));
TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n"));
} else if (arg_types == std::vector<string>({kSingleInputType})) {
TF_CHECK_OK(cc->Append(op_name));
TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ",
arg_names[0], ", opts);\n}\n\n")));
} else if (arg_types ==
std::vector<string>({kSingleInputType, kSingleInputType})) {
TF_CHECK_OK(cc->Append(op_name));
// TODO(josh11b): Word wrap this if it ever becomes necessary.
TF_CHECK_OK(
cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0],
", ", arg_names[1], ", opts);\n}\n\n")));
} else {
TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n"));
TF_CHECK_OK(cc->Append(op_name));
TF_CHECK_OK(cc->Append(
" NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n"
" opts.op_registry());\n"));
for (size_t i = 0; i < arg_names.size(); ++i) {
if (i < static_cast<size_t>(op_def.input_arg_size())) {
TF_CHECK_OK(cc->Append(
strings::StrCat(" node_builder.Input(", arg_names[i], ");\n")));
} else {
TF_CHECK_OK(
cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i],
"\", ", arg_names[i], ");\n")));
}
}
TF_CHECK_OK(
cc->Append(" return opts.FinalizeBuilder(&node_builder);\n"
"}\n\n"));
}
}
// Converts:
// bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
string GetPath(const std::string& dot_h_fname) {
auto pos = dot_h_fname.find("/genfiles/");
string result = dot_h_fname;
if (pos != string::npos) {
// The - 1 accounts for the terminating null character (\0) in "/genfiles/".
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
}
if (result.size() > sizeof("external/") &&
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
result = result.substr(sizeof("external/") - 1);
pos = result.find("/");
if (pos != string::npos) {
result = result.substr(pos + 1);
}
}
return result;
}
// Converts:
// cc/ops/gen_foo_ops.h
// to:
// CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const std::string& path) {
string guard;
guard.reserve(path.size() + 1); // + 1 -> trailing _
for (const char c : path) {
if (c >= 'A' && c <= 'Z') {
guard += c;
} else if (c >= 'a' && c <= 'z') {
guard += c + 'A' - 'a';
} else {
guard += '_';
}
}
guard += '_';
return guard;
}
} // namespace
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
const std::string& dot_cc_fname) {
Env* env = Env::Default();
std::unique_ptr<WritableFile> h;
std::unique_ptr<WritableFile> cc;
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
// .h Header
const string include = GetPath(dot_h_fname);
const string guard = ToGuard(include);
// TODO(josh11b): Mention the library for which wrappers are being generated.
Status s;
s = h->Append(
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
"#ifndef ",
guard,
"\n"
"#define ",
guard, R"header(
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace ops {
// These add a node to the graph from opts.
//
// Note for "NodeOut" inputs, you will typically either pass
// * a {Node*, int index} (to pass the index-th output of that node), or
// * a Node* (to pass the first output of that node).
)header"));
TF_CHECK_OK(s);
// .cc Header
s = cc->Append(
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
"#include \"",
include, R"header("
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
namespace ops {
)header"));
TF_CHECK_OK(s);
for (const auto& op_def : ops.op()) {
WriteCCOp(op_def, h.get(), cc.get());
}
// .h Footer
s = h->Append(strings::StrCat(R"footer(} // namespace ops
} // namespace tensorflow
#endif // )footer",
guard, "\n"));
TF_CHECK_OK(s);
// .cc Footer
s = cc->Append(R"footer(} // namespace ops
} // namespace tensorflow
)footer");
TF_CHECK_OK(s);
TF_CHECK_OK(cc->Close());
TF_CHECK_OK(h->Close());
}
} // namespace tensorflow
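For context, the generator deleted above emitted wrappers of roughly this
shape (a reconstruction, not verbatim output):
// Two plain inputs and only defaulted attrs, so WriteCCOp takes the
// BinaryOp path; extra attrs and the op name went through the opts object.
Node* MatMul(NodeOut a, NodeOut b, const GraphDefBuilder::Options& opts) {
  static const string kOpName = "MatMul";
  return BinaryOp(kOpName, a, b, opts);
}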

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,119 +14,59 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace ops {
namespace {
const string& OpName() {
static const string kOpName = "Const";
return kOpName;
}
} // namespace
#define DEFINE_CONST_SCALAR(TYPE) \
Node* Const(TYPE s, const GraphDefBuilder::Options& options) { \
return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \
Output Const(const Scope& scope, const Input::Initializer& val) {
if (!scope.ok()) return Output();
if (!val.status.ok()) {
scope.UpdateStatus(val.status);
return Output();
}
#define DEFINE_CONST_VECTOR(TYPE) \
Node* Const(gtl::ArraySlice<TYPE> v, \
const GraphDefBuilder::Options& options) { \
return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \
}
Node* ret;
Graph* graph = scope.graph();
const string unique_name = scope.GetUniqueNameForOp("Const");
auto builder = NodeBuilder(unique_name, "Const")
.Attr("value", val.tensor)
.Attr("dtype", val.tensor.dtype());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(graph, &ret));
#define DEFINE_CONST_TENSOR(TYPE, ...) \
Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
const GraphDefBuilder::Options& options) { \
if (options.HaveError()) return nullptr; \
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \
options.op_registry()); \
const DataType dt = DataTypeToEnum<TYPE>::v(); \
if (t.size() == 1) { \
TensorProto proto; \
proto.set_dtype(dt); \
shape.AsProto(proto.mutable_tensor_shape()); \
__VA_ARGS__; \
node_builder.Attr("dtype", dt).Attr("value", proto); \
} else { \
Tensor tensor(dt, shape); \
if (tensor.NumElements() != static_cast<int64>(t.size())) { \
options.UpdateStatus(errors::InvalidArgument( \
t.size(), " values provided to Const() != ", tensor.NumElements(), \
" elements for shape ", shape.DebugString())); \
} else { \
std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \
node_builder.Attr("dtype", dt).Attr("value", tensor); \
} \
} \
return options.FinalizeBuilder(&node_builder); \
}
if (!scope.ok()) return Output();
#define DEFINE_CONST_IMPL(TYPE, ...) \
DEFINE_CONST_SCALAR(TYPE) \
DEFINE_CONST_VECTOR(TYPE) \
DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__)
#define DEFINE_CONST(TYPE, FIELD) \
DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());)
DEFINE_CONST(float, float_val);
DEFINE_CONST(double, double_val);
DEFINE_CONST(int32, int_val);
DEFINE_CONST(uint8, int_val);
DEFINE_CONST(int16, int_val);
DEFINE_CONST(int8, int_val);
DEFINE_CONST(int64, int64_val);
DEFINE_CONST(bool, bool_val);
DEFINE_CONST_IMPL(Eigen::half, proto.add_half_val(t.begin()->x));
DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
proto.add_scomplex_val(t.begin()->imag()););
DEFINE_CONST_IMPL(complex128, proto.add_dcomplex_val(t.begin()->real());
proto.add_dcomplex_val(t.begin()->imag()););
Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
if (options.HaveError()) return nullptr;
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
options.op_registry());
TensorProto proto;
proto.set_dtype(DT_STRING);
TensorShape({}).AsProto(proto.mutable_tensor_shape());
proto.add_string_val(s.data(), s.size());
node_builder.Attr("dtype", DT_STRING).Attr("value", proto);
return options.FinalizeBuilder(&node_builder);
return Output(ret);
}
DEFINE_CONST_VECTOR(string)
DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());)
#undef DEFINE_CONST
#undef DEFINE_CONST_IMPL
#undef DEFINE_CONST_TENSOR
#undef DEFINE_CONST_VECTOR
#undef DEFINE_CONST_SCALAR
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
if (options.HaveError()) return nullptr;
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
options.op_registry());
node_builder.Attr("dtype", t.dtype()).Attr("value", t);
return options.FinalizeBuilder(&node_builder);
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) {
if (!inp.status().ok()) {
scope.UpdateStatus(inp.status());
return NodeBuilder::NodeOut(inp.node(), inp.index());
}
if (inp.node()) {
return NodeBuilder::NodeOut(inp.node(), inp.index());
}
if (!inp.node_name().empty()) {
return NodeBuilder::NodeOut(inp.node_name(), inp.index(), inp.data_type());
}
auto transformed = Input{
Const(scope.NewSubScope("Const"), Input::Initializer(inp.tensor()))};
return NodeBuilder::NodeOut{transformed.node(), transformed.index()};
}
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) {
if (options.HaveError()) return nullptr;
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
options.op_registry());
node_builder.Attr("dtype", proto.dtype()).Attr("value", proto);
return options.FinalizeBuilder(&node_builder);
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
const InputList& inp) {
std::vector<NodeBuilder::NodeOut> out;
for (const auto& i : inp) {
const auto node_out = AsNodeOut(scope, i);
if (!scope.ok()) {
return {};
}
out.push_back(node_out);
}
return out;
}
} // namespace ops
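AsNodeOut is what lets op wrappers accept plain C++ values: when an Input
carries neither a node nor a node name, a Const node is synthesized under a
"Const" sub-scope. A sketch of how a generated wrapper might use it (the
wrapper body here is hypothetical):
// Hypothetical wrapper: ops::Square(scope, 3.0f) works because AsNodeOut
// turns the literal into a Const node before it is wired into the builder.
ops::Output Square(const Scope& scope, ops::Input x) {
  if (!scope.ok()) return ops::Output();
  auto x_out = ops::AsNodeOut(scope, x);
  if (!scope.ok()) return ops::Output();
  Node* ret;
  auto builder =
      NodeBuilder(scope.GetUniqueNameForOp("Square"), "Square").Input(x_out);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return ops::Output();
  return ops::Output(ret);
}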

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,75 +13,53 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
#define TENSORFLOW_CC_OPS_CONST_OP_H_
#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
namespace ops {
// If a shape is specified, you may either provide the same number of values,
// or a single value and that value will be duplicated to fill out the Tensor.
#define DECLARE_CONST(TYPE) \
Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \
Node* Const(gtl::ArraySlice<TYPE> v, \
const GraphDefBuilder::Options& options); /* Vector */ \
Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
const GraphDefBuilder::Options& options); /* Tensor */ \
inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \
const GraphDefBuilder::Options& options) { \
return Const(gtl::ArraySlice<TYPE>(v), options); \
} \
inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \
const TensorShape& shape, \
const GraphDefBuilder::Options& options) { \
return Const(gtl::ArraySlice<TYPE>(t), shape, options); \
Output Const(const Scope& scope, const Input::Initializer& val);
template <typename T>
Output Const(const Scope& scope, const Input::Initializer& val) {
if (!scope.ok()) return Output();
if (!val.status.ok()) {
scope.UpdateStatus(val.status);
return Output();
}
DECLARE_CONST(Eigen::half);
DECLARE_CONST(float);
DECLARE_CONST(double);
DECLARE_CONST(int32);
DECLARE_CONST(uint8);
DECLARE_CONST(int16);
DECLARE_CONST(int8);
DECLARE_CONST(complex64);
DECLARE_CONST(complex128);
DECLARE_CONST(int64);
DECLARE_CONST(bool);
#undef DECLARE_CONST
// String
Node* Const(StringPiece s, const GraphDefBuilder::Options& options);
Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options);
Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape,
const GraphDefBuilder::Options& options);
inline Node* Const(std::initializer_list<string> v,
const GraphDefBuilder::Options& options) {
return Const(gtl::ArraySlice<string>(v), options);
}
inline Node* Const(std::initializer_list<string> t, const TensorShape& shape,
const GraphDefBuilder::Options& options) {
return Const(gtl::ArraySlice<string>(t), shape, options);
typedef typename Input::Initializer::RealType<T>::type DstT;
if (val.tensor.NumElements() > 0) {
// TODO(keveman): Implement the in-situ cast.
scope.UpdateStatus(errors::Unimplemented(
"Explict cast of a non-empty tensor not implemented yet"));
return Output();
}
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
return Const(scope, Input::Initializer(t));
}
// A Tensor of any type.
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options);
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options);
template <class T>
Node* EmptyConst(const GraphDefBuilder::Options& options) {
return Const(gtl::ArraySlice<T>(), options);
template <typename T>
Output Const(const Scope& scope, const T& v, const TensorShape shape) {
return Const(scope, Input::Initializer(v, shape));
}
// TODO(josh11b): Support other types (e.g. quantized ints, float16).
template <typename T>
Output Const(const Scope& scope, const std::initializer_list<T>& v,
const TensorShape shape) {
return Const(scope, Input::Initializer(v, shape));
}
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
const InputList& inp);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_CC_OPS_CONST_OP_H_
#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
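A brief usage sketch for the Const overloads declared above (the same
patterns are exercised by the tests that follow):
Scope root = Scope::NewRootScope();
auto a = ops::Const(root, 42.0f);                 // scalar, DT_FLOAT
auto b = ops::Const(root, {1, 2, 3, 4}, {2, 2});  // values + explicit shape
auto c = ops::Const<int>(root, {{}});             // empty tensor, typed via <int>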

View File

@ -0,0 +1,128 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
template <typename T>
void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
EXPECT_EQ(tensor.dtype(), dtype);
test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
}
void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
TensorShape expected_shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
EXPECT_EQ(dtype, expected_dtype);
EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
}
} // namespace
TEST(ConstOpTest, Basic) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0f);
TF_EXPECT_OK(root.status());
EXPECT_EQ(c.op().output_type(0), DT_FLOAT);
ExpectNodeEqual<float>(c.node(), {42.0f}, {});
}
TEST(ConstOpTest, MultiDim) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, {{2.0}, {3.0}});
TF_CHECK_OK(root.status());
EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
ExpectNodeEqual<double>(c.node(), {2.0, 3.0}, {2, 1});
}
TEST(ConstOpTest, Empty) {
Scope root = Scope::NewRootScope();
auto c1 = ops::Const(root, {});
TF_CHECK_OK(root.status());
ExpectTypeAndShape(c1.node(), DT_FLOAT, {0});
auto c2 = ops::Const(root, {{}});
TF_CHECK_OK(root.status());
ExpectTypeAndShape(c2.node(), DT_FLOAT, {1, 0});
auto c3 = ops::Const(root, {{{}, {}}});
TF_CHECK_OK(root.status());
ExpectTypeAndShape(c3.node(), DT_FLOAT, {1, 2, 0});
auto c4 = ops::Const<int>(root, {{{}}});
TF_CHECK_OK(root.status());
ExpectTypeAndShape(c4.node(), DT_INT32, {1, 1, 0});
ops::Const(root, {{}, {{}}});
EXPECT_FALSE(root.status().ok());
}
TEST(ConstOpTest, WithExplicitShape) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, 42.0, {2, 2});
TF_CHECK_OK(root.status());
EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
ExpectNodeEqual<double>(c.node(), {42.0, 42.0, 42.0, 42.0}, {2, 2});
auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3});
TF_CHECK_OK(root.status());
EXPECT_EQ(d.op().output_type(0), DT_STRING);
ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
}
TEST(ConstOpTest, InvalidInitializer) {
Scope root = Scope::NewRootScope();
ops::Const(root, {{2.0}, {"df"}});
EXPECT_FALSE(root.status().ok());
}
TEST(ConstOpTest, Names) {
Scope root = Scope::NewRootScope();
auto c = ops::Const(root, {{2.0}, {3.0}});
EXPECT_EQ(c.node()->name(), "Const");
auto c_1 = ops::Const(root, {{2.0}, {3.0}});
EXPECT_EQ(c_1.node()->name(), "Const_1");
auto x = ops::Const(root.WithOpName("x"), 1);
EXPECT_EQ(x.node()->name(), "x");
auto x_1 = ops::Const(root.WithOpName("x"), 1);
EXPECT_EQ(x_1.node()->name(), "x_1");
Scope child = root.NewSubScope("c");
auto c_y = ops::Const(child.WithOpName("y"), 1);
EXPECT_EQ(c_y.node()->name(), "c/y");
auto c_y_1 = ops::Const(child.WithOpName("y"), 1);
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
}
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// #include this file to get access to the standard set of C++ graph
// definition libraries.
#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/candidate_sampling_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/io_ops.h"
@ -28,6 +27,7 @@ limitations under the License.
#include "tensorflow/cc/ops/logging_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/no_op.h"
#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sparse_ops.h"
@ -36,4 +36,4 @@ limitations under the License.
#include "tensorflow/cc/ops/training_ops.h"
#include "tensorflow/cc/ops/user_ops.h"
#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_
#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_

View File

@ -49,31 +49,33 @@ struct Options {
GraphDef CreateGraphDef() {
// TODO(jeff,opensource): This should really be a more interesting
// computation. Maybe turn this into an mnist model instead?
GraphDefBuilder b;
Scope root = Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
// Store rows [3, 2] and [-1, 0] in row major format.
Node* a = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts());
// x is from the feed.
Node* x = Const({0.f}, {2, 1}, b.opts().WithName("x"));
// a = [3 2; -1 0]
auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
// y = A * x
Node* y = MatMul(a, x, b.opts().WithName("y"));
// x = [1.0; 1.0]
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
// y = a * x
auto y = MatMul(root.WithOpName("y"), a, x);
// y2 = y.^2
Node* y2 = Square(y, b.opts());
auto y2 = Square(root, y);
// y2_sum = sum(y2)
Node* y2_sum = Sum(y2, Const(0, b.opts()), b.opts());
auto y2_sum = Sum(root, y2, 0);
// y_norm = sqrt(y2_sum)
Node* y_norm = Sqrt(y2_sum, b.opts());
auto y_norm = Sqrt(root, y2_sum);
// y_normalized = y ./ y_norm
Div(y, y_norm, b.opts().WithName("y_normalized"));
Div(root.WithOpName("y_normalized"), y, y_norm);
GraphDef def;
TF_CHECK_OK(b.ToGraphDef(&def));
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}

View File

@ -42,13 +42,12 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT);
test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; });
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* node1 = ops::Const(test_tensor1, b.opts());
Node* node2 = ops::Const(test_tensor2, b.opts());
const string result_name = ops::MatMul(node1, node2, b.opts())->name();
auto root = Scope::NewRootScope().ExitOnError();
ops::Output m = ops::MatMul(root, test_tensor1, test_tensor2);
const string result_name = m.node()->name();
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
TF_ASSERT_OK(

View File

@ -1449,6 +1449,7 @@ tf_cc_tests(
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
],
@ -1752,6 +1753,7 @@ tf_cc_test_gpu(
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:ops_util",
],

View File

@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/op.h"
@ -35,9 +34,9 @@ class GpuStreamUtilTest : public OpsTestBase {
};
TEST_F(GpuStreamUtilTest, BogusOpts) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
auto root = Scope::NewRootScope().ExitOnError();
Graph g(OpRegistry::Global());
TF_ASSERT_OK(b.ToGraph(&g));
root.ToGraph(&g);
std::unordered_map<int, int> node_to_stream_id;
gpu_stream_util::AssignStreamsOpts opts;
Status status;
@ -55,9 +54,9 @@ TEST_F(GpuStreamUtilTest, BogusOpts) {
}
TEST_F(GpuStreamUtilTest, EmptyGraph) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
auto root = Scope::NewRootScope().ExitOnError();
Graph g(OpRegistry::Global());
TF_ASSERT_OK(b.ToGraph(&g));
root.ToGraph(&g);
std::unordered_map<int, int> node_to_stream_id;
gpu_stream_util::AssignStreamsOpts opts;
TF_ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
@ -65,11 +64,10 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
auto root = Scope::NewRootScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(b.ToGraph(&g));
TF_ASSERT_OK(root.ToGraph(&g));
std::unordered_map<int, int> node_to_stream_id;
gpu_stream_util::AssignStreamsOpts opts;
@ -85,11 +83,10 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
auto root = Scope::NewRootScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(b.ToGraph(&g));
TF_ASSERT_OK(root.ToGraph(&g));
std::unordered_map<int, int> node_to_stream_id;
gpu_stream_util::AssignStreamsOpts opts;
@ -107,14 +104,13 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
}
TEST_F(GpuStreamUtilTest, StreamOverrides) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
ops::_Recv(DT_FLOAT, "input", "/cpu:0", 0, "/gpu:0",
b.opts().WithName("input"));
auto n = ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
ops::_Send(n, "output", "/gpu:0", 0, "/cpu:0", b.opts().WithName("output"));
auto root = Scope::NewRootScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
"/gpu:0");
ops::Output n = ops::MatMul(root, {}, {});
ops::_Send(root.WithOpName("output"), n, "output", "/gpu:0", 0, "/cpu:0");
Graph g(OpRegistry::Global());
TF_ASSERT_OK(b.ToGraph(&g));
TF_ASSERT_OK(root.ToGraph(&g));
// Perform stream assignment using a large number of streams, but with
// op types constrained to specific streams.

View File

@ -133,28 +133,41 @@ REGISTER_OP("Input").Output("o: float");
REGISTER_OP("BoolInput").Output("o: bool");
REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("Input", opts);
ops::Output ConstructOp(const Scope& scope, const string& op_type,
const gtl::ArraySlice<ops::Input>& inputs) {
if (!scope.ok()) return ops::Output();
const string unique_name = scope.GetUniqueNameForOp(op_type);
auto builder = NodeBuilder(unique_name, op_type);
for (auto const& input : inputs) {
builder.Input(ops::NodeOut(input.node(), input.index()));
}
scope.UpdateBuilder(&builder);
Node* ret;
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return ops::Output();
return ops::Output(ret);
}
Node* BoolInput(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("BoolInput", opts);
ops::Output Input(const Scope& scope) {
return ConstructOp(scope, "Input", {});
}
Node* Combine(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("Combine", a, b, opts);
ops::Output BoolInput(const Scope& scope) {
return ConstructOp(scope, "BoolInput", {});
}
ops::Output Combine(const Scope& scope, ops::Input a, ops::Input b) {
return ConstructOp(scope, "Combine", {a, b});
}
class GraphPartitionTest : public ::testing::Test {
protected:
GraphPartitionTest()
: in_(GraphDefBuilder::kFailImmediately),
builder_a_(GraphDefBuilder::kFailImmediately),
builder_b_(GraphDefBuilder::kFailImmediately),
a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")),
b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) {
}
: in_(Scope::NewRootScope().ExitOnError()),
scope_a_(Scope::NewRootScope().ExitOnError().WithDevice(
"/job:a/replica:0/task:0/cpu:0")),
scope_b_(Scope::NewRootScope().ExitOnError().WithDevice(
"/job:a/replica:0/task:0/cpu:1")) {}
const GraphDef& ToGraphDef() {
in_.ToGraphDef(&in_graph_def_);
@ -163,187 +176,187 @@ class GraphPartitionTest : public ::testing::Test {
void ExpectMatchA() {
GraphDef graph_def;
builder_a_.ToGraphDef(&graph_def);
scope_a_.ToGraphDef(&graph_def);
string a = "/job:a/replica:0/task:0/cpu:0";
TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
}
void ExpectMatchB() {
GraphDef graph_def;
builder_b_.ToGraphDef(&graph_def);
scope_b_.ToGraphDef(&graph_def);
string b = "/job:a/replica:0/task:0/cpu:1";
TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
}
GraphDefBuilder in_;
Scope in_;
GraphDef in_graph_def_;
GraphDefBuilder builder_a_;
GraphDefBuilder builder_b_;
GraphDefBuilder::Options a_opts_;
GraphDefBuilder::Options b_opts_;
Scope scope_a_;
Scope scope_b_;
std::unordered_map<string, GraphDef> partitions_;
};
TEST_F(GraphPartitionTest, SingleDevice) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Combine(a1, a1, in_.opts().WithName("A2"));
auto a1 = Input(in_.WithOpName("A1"));
Combine(in_.WithOpName("A2"), a1, a1);
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(1, partitions_.size());
a1 = Input(a_opts_.WithName("A1"));
Combine(a1, a1, a_opts_.WithName("A2"));
a1 = Input(scope_a_.WithOpName("A1"));
Combine(scope_a_.WithOpName("A2"), a1, a1);
ExpectMatchA();
}
TEST_F(GraphPartitionTest, CrossDeviceData) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
Combine(a1, b1, in_.opts().WithName("B2"));
auto a1 = Input(in_.WithOpName("A1"));
auto b1 = Input(in_.WithOpName("B1"));
Combine(in_.WithOpName("B2"), a1, b1);
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = Input(a_opts_.WithName("A1"));
_Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
a1 = Input(scope_a_.WithOpName("A1"));
_Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
ExpectMatchA();
b1 = Input(b_opts_.WithName("B1"));
Node* recv =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
Combine(recv, b1, b_opts_.WithName("B2"));
b1 = Input(scope_b_.WithOpName("B1"));
auto recv =
_Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
Combine(scope_b_.WithOpName("B2"), recv, b1);
ExpectMatchB();
}
TEST_F(GraphPartitionTest, CrossDeviceControl) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
auto a1 = Input(in_.WithOpName("A1"));
auto b1 = Input(in_.WithOpName("B1"));
Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = Input(a_opts_.WithName("A1"));
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
_Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1"));
a1 = Input(scope_a_.WithOpName("A1"));
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
ExpectMatchA();
Node* recv =
_Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2"));
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
b1 = Input(b_opts_.WithName("B1"));
Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
auto recv =
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
b1 = Input(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
ExpectMatchB();
}
TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
Combine(a1, b1, in_.opts().WithName("B2"));
Combine(a1, a1, in_.opts().WithName("B3"));
auto a1 = Input(in_.WithOpName("A1"));
auto b1 = Input(in_.WithOpName("B1"));
Combine(in_.WithOpName("B2"), a1, b1);
Combine(in_.WithOpName("B3"), a1, a1);
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = Input(a_opts_.WithName("A1"));
_Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
a1 = Input(scope_a_.WithOpName("A1"));
_Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
ExpectMatchA();
Node* recv =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
b1 = Input(b_opts_.WithName("B1"));
Combine(recv, b1, b_opts_.WithName("B2"));
Combine(recv, recv, b_opts_.WithName("B3"));
auto recv =
_Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
b1 = Input(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2"), recv, b1);
Combine(scope_b_.WithOpName("B3"), recv, recv);
ExpectMatchB();
}
TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
Input(in_.opts().WithName("B3").WithControlInput(a1));
auto a1 = Input(in_.WithOpName("A1"));
auto b1 = Input(in_.WithOpName("B1"));
Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
Input(in_.WithOpName("B3").WithControlDependencies(a1));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = Input(a_opts_.WithName("A1"));
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
_Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
a1 = Input(scope_a_.WithOpName("A1"));
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
ExpectMatchA();
Node* recv =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
b1 = Input(b_opts_.WithName("B1"));
Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
Input(b_opts_.WithName("B3").WithControlInput(id));
auto recv =
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
b1 = Input(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
Input(scope_b_.WithOpName("B3").WithControlDependencies(id));
ExpectMatchB();
}
TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = Input(in_.opts().WithName("A1"));
Node* b1 = Input(in_.opts().WithName("B1"));
Combine(a1, b1, in_.opts().WithName("B2"));
Input(in_.opts().WithName("B3").WithControlInput(a1));
auto a1 = Input(in_.WithOpName("A1"));
auto b1 = Input(in_.WithOpName("B1"));
Combine(in_.WithOpName("B2"), a1, b1);
Input(in_.WithOpName("B3").WithControlDependencies(a1));
Partition(ToGraphDef(), &partitions_);
EXPECT_EQ(2, partitions_.size());
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = Input(a_opts_.WithName("A1"));
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
a1 = Input(scope_a_.WithOpName("A1"));
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
// NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could
// use A1/_0 -> A1/_4 as the control edge instead, as a minor optimization.
_Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
_Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4"));
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
_Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b);
ExpectMatchA();
Node* recv1 =
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3"));
Node* recv2 =
_Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5"));
b1 = Input(b_opts_.WithName("B1"));
Combine(recv2, b1, b_opts_.WithName("B2"));
Input(b_opts_.WithName("B3").WithControlInput(id1));
auto recv1 =
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1);
auto recv2 =
_Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b);
b1 = Input(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2"), recv2, b1);
Input(scope_b_.WithOpName("B3").WithControlDependencies(id1));
ExpectMatchB();
}
TEST_F(GraphPartitionTest, CrossDeviceLoop) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = BoolInput(in_.opts().WithName("A1"));
Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2"));
Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
LoopCond(a3, in_.opts().WithName("A4"));
Node* b1 = Identity(a3, in_.opts().WithName("B1"));
NextIteration(b1, in_.opts().WithName("A5"));
auto a1 = BoolInput(in_.WithOpName("A1"));
auto a2 = Enter(in_.WithOpName("A2"), a1, "foo");
auto a3 =
Merge(in_.WithOpName("A3"), {a2, ops::Input("A5", 0, DT_BOOL)}).output;
LoopCond(in_.WithOpName("A4"), a3);
auto b1 = Identity(in_.WithOpName("B1"), a3);
NextIteration(in_.WithOpName("A5"), b1);
CheckLoopConstruction(ToGraphDef());
}
TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = BoolInput(in_.opts().WithName("A1"));
Node* a2 = Enter(a1, "foo", in_.opts().WithName("B2"));
Node* a3 = Merge({a2, {"B5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
LoopCond(a3, in_.opts().WithName("A4"));
Node* b1 = Identity(a3, in_.opts().WithName("B1"));
NextIteration(b1, in_.opts().WithName("B5"));
auto a1 = BoolInput(in_.WithOpName("A1"));
auto a2 = Enter(in_.WithOpName("B2"), a1, "foo");
auto a3 =
Merge(in_.WithOpName("A3"), {a2, ops::Input("B5", 0, DT_BOOL)}).output;
LoopCond(in_.WithOpName("A4"), a3);
auto b1 = Identity(in_.WithOpName("B1"), a3);
NextIteration(in_.WithOpName("B5"), b1);
std::unordered_map<string, GraphDef> partitions;
Partition(ToGraphDef(), &partitions);
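
The pattern running through these test updates: the old GraphDefBuilder API took a `GraphDefBuilder::Options` as the trailing argument and returned raw `Node*`, while the new API takes a `Scope` as the first argument and returns `Output` values. A minimal sketch of the new style, outside the diff (the helper name and the `Const`/`Add` op choices are illustrative, not from this commit):

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

// Hypothetical helper showing the new calling convention; the device
// strings match the ones used in the partition tests above.
tensorflow::GraphDef BuildTwoDeviceGraph() {
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  // One root scope; derived scopes share the same underlying graph.
  tensorflow::Scope root = tensorflow::Scope::NewRootScope().ExitOnError();
  tensorflow::Scope cpu0 = root.WithDevice("/job:a/replica:0/task:0/cpu:0");
  tensorflow::Scope cpu1 = root.WithDevice("/job:a/replica:0/task:0/cpu:1");
  // The op name and any control dependencies hang off the Scope that is
  // passed as the first argument; the result is an Output value.
  auto a1 = Const(cpu0.WithOpName("A1"), 1.0f);
  Add(cpu1.WithOpName("B2").WithControlDependencies(a1), a1, a1);
  tensorflow::GraphDef gdef;
  TF_CHECK_OK(root.ToGraphDef(&gdef));
  return gdef;
}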

@ -91,14 +91,14 @@ struct ImmutableConstantOpTest {};
TEST(ImmutableConstantOpTest, Simple) {
const TensorShape kTestTensorShape({4, 1});
const TensorShape kTestTensorShapeT({1, 4});
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* node1 =
ops::ImmutableConst(DT_FLOAT, kTestTensorShape, "test://2", b.opts());
Node* node2 =
ops::ImmutableConst(DT_FLOAT, kTestTensorShapeT, "test://3", b.opts());
Node* result = ops::MatMul(node1, node2, b.opts());
auto root = Scope::NewRootScope().ExitOnError();
auto node1 =
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShape, "test://2");
auto node2 =
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShapeT, "test://3");
auto result = ops::MatMul(root, node1, node2);
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
SessionOptions session_options;
session_options.env = Env::Default();
session_options.config.mutable_graph_options()
@ -108,7 +108,7 @@ TEST(ImmutableConstantOpTest, Simple) {
ASSERT_TRUE(session != nullptr) << "Failed to create session";
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run({}, {result->name() + ":0"}, {}, &outputs));
TF_ASSERT_OK(session->Run({}, {result.node()->name() + ":0"}, {}, &outputs));
ASSERT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f);
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f);
@ -122,14 +122,14 @@ TEST(ImmutableConstantOpTest, Simple) {
TEST(ImmutableConstantOpTest, ExecutionError) {
const TensorShape kBadTensorShape({40, 100});
const TensorShape kTestTensorShapeT({1, 4});
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* node1 =
ops::ImmutableConst(DT_FLOAT, kBadTensorShape, "test://2", b.opts());
Node* node2 =
ops::ImmutableConst(DT_FLOAT, kTestTensorShapeT, "test://3", b.opts());
Node* result = ops::MatMul(node1, node2, b.opts());
auto root = Scope::NewRootScope().ExitOnError();
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test://2");
auto node2 =
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShapeT, "test://3");
auto result = ops::MatMul(root, node1, node2);
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
SessionOptions session_options;
session_options.env = Env::Default();
std::unique_ptr<Session> session(NewSession(session_options));
@ -137,8 +137,9 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
// Check that the run returned an error.
EXPECT_EQ(session->Run({}, {result->name() + ":0"}, {}, &outputs).code(),
error::INTERNAL);
EXPECT_EQ(
session->Run({}, {result.node()->name() + ":0"}, {}, &outputs).code(),
error::INTERNAL);
}
Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
@ -158,19 +159,18 @@ Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
TEST(ImmutableConstantOpTest, FromFile) {
const TensorShape kFileTensorShape({1000, 1});
Env* env = Env::Default();
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
auto root = Scope::NewRootScope().ExitOnError();
string two_file, three_file;
TF_ASSERT_OK(CreateTempFile(env, 2.0f, 1000, &two_file));
TF_ASSERT_OK(CreateTempFile(env, 3.0f, 1000, &three_file));
Node* node1 =
ops::ImmutableConst(DT_FLOAT, kFileTensorShape, two_file, b.opts());
Node* node2 =
ops::ImmutableConst(DT_FLOAT, kFileTensorShape, three_file, b.opts());
Node* result =
ops::MatMul(node1, node2, b.opts().WithAttr("transpose_b", true));
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, two_file);
auto node2 =
ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, three_file);
auto result = ops::MatMul(root, node1, node2, ops::MatMul::TransposeB(true));
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
@ -179,7 +179,7 @@ TEST(ImmutableConstantOpTest, FromFile) {
ASSERT_TRUE(session != nullptr) << "Failed to create session";
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run({}, {result->name() + ":0"}, {}, &outputs));
TF_ASSERT_OK(session->Run({}, {result.node()->name() + ":0"}, {}, &outputs));
ASSERT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f);
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f);
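
Because the wrappers now return `Output` values rather than `Node*`, the fetch endpoint is spelled `result.node()->name() + ":0"` in the updated `session->Run` calls above. A hypothetical helper capturing that idiom (assuming `Output` exposes `node()` and `index()`, as this commit's usage suggests):

#include <string>
#include "tensorflow/cc/framework/ops.h"

// Builds a fetchable endpoint name "<node name>:<output index>" from an
// Output value; the tests above hard-code output index 0 as ":0".
std::string FetchName(const tensorflow::ops::Output& out) {
  return out.node()->name() + ":" + std::to_string(out.index());
}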

@ -1047,7 +1047,7 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
int depth, int kernel_rows, int kernel_cols,
int stride, Padding padding, int num_threads,
bool use_gpu, const string& label) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
auto root = Scope::NewRootScope().ExitOnError();
int64 out_height, out_width, pad_rows, pad_cols;
TF_CHECK_OK(GetWindowedOutputSize(rows, kernel_rows, stride, padding,
@ -1057,27 +1057,25 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
Tensor input_data(DT_FLOAT, TensorShape({batch_size, rows, cols, depth}));
input_data.flat<float>().setRandom();
Node* input_data_node = ops::Const(input_data, b.opts());
Tensor output_data(DT_FLOAT,
TensorShape({batch_size, out_height, out_width, depth}));
output_data.flat<float>().setRandom();
Node* output_data_node = ops::Const(output_data, b.opts());
Tensor output_diff(DT_FLOAT,
TensorShape({batch_size, out_height, out_width, depth}));
output_diff.flat<float>().setRandom();
Node* output_diff_node = ops::Const(output_diff, b.opts());
CHECK_EQ(kernel_rows, kernel_cols);
ops::MaxPoolGrad(input_data_node, output_data_node, output_diff_node,
ops::MaxPoolGrad(root, input_data, output_data, output_diff,
{1, kernel_rows, kernel_cols, 1} /* ksize */,
{1, stride, stride, 1} /* stride */,
padding == VALID ? "VALID" : "SAME", b.opts());
Graph* g = new Graph(OpRegistry::Global());
TF_CHECK_OK(b.ToGraph(g));
padding == VALID ? "VALID" : "SAME");
TF_CHECK_OK(root.status());
Graph g(OpRegistry::Global());
root.ToGraph(&g);
string device = use_gpu ? "gpu" : "cpu";
test::Benchmark(device, g).Run(iters);
test::Benchmark(device, &g).Run(iters);
testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
testing::SetLabel(label);
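
The benchmark now follows a three-step shape: build ops against a `Scope`, surface accumulated construction errors via `root.status()`, then convert to a stack-allocated `Graph` for the harness. A stripped-down sketch of that scaffolding, with the op body elided and the function name assumed:

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"

static void BM_Sketch(int iters, bool use_gpu) {
  auto root = tensorflow::Scope::NewRootScope().ExitOnError();
  // ... construct ops against `root` here ...
  TF_CHECK_OK(root.status());  // errors from op construction surface here
  tensorflow::Graph g(tensorflow::OpRegistry::Global());
  root.ToGraph(&g);
  tensorflow::test::Benchmark(use_gpu ? "gpu" : "cpu", &g).Run(iters);
}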

@ -669,12 +669,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
// Builds the graph.
const string temp_filename =
io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* filename = ops::Const(test::AsScalar<string>(temp_filename), b.opts());
Node* tensor_names =
ops::Const(test::AsTensor<string>({"my_tensor"}), b.opts());
Node* tensors = ops::Const(tensor, b.opts());
ops::Save(filename, tensor_names, {tensors}, b.opts());
auto root = Scope::NewRootScope().ExitOnError();
const string tensor_name = "my_tensor";
ops::Save(root, temp_filename, {tensor_name}, {{tensor}});
// Disables optimizations.
SessionOptions session_options;
@ -682,13 +679,14 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
->mutable_optimizer_options()
->set_opt_level(tensorflow::OptimizerOptions_Level_L0);
Graph* g = new Graph(OpRegistry::Global());
TF_CHECK_OK(b.ToGraph(g));
TF_CHECK_OK(root.status());
Graph g(OpRegistry::Global());
root.ToGraph(&g);
VLOG(1) << "Save op's output path: " << temp_filename;
VLOG(1) << "# nodes in Graph: " << g->num_nodes();
VLOG(1) << "# nodes in Graph: " << g.num_nodes();
testing::StartTiming();
test::Benchmark("cpu", g, &session_options).Run(iters);
test::Benchmark("cpu", &g, &session_options).Run(iters);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
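
The `Save` rewrite above is the clearest instance of the release-note claim that C++ values can be used directly as operands: the filename string, the tensor-name list, and the `Tensor` itself are passed straight through and wrapped as constants implicitly, where the old code built three explicit `Const` nodes. A hedged sketch of the same idiom (function and tensor names assumed):

#include <string>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

// Constructs a graph that writes one tensor to a checkpoint file; every
// argument after the scope is a plain C++ value implicitly Const-ified.
void SaveOneTensor(const tensorflow::Tensor& t, const std::string& path) {
  auto root = tensorflow::Scope::NewRootScope().ExitOnError();
  const std::string name = "my_tensor";
  tensorflow::ops::Save(root, path, {name}, {{t}});
  TF_CHECK_OK(root.status());
}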

@ -87,50 +87,44 @@ Status ReadTensorFromImageFile(string file_name, const int input_height,
const int input_width, const float input_mean,
const float input_std,
std::vector<Tensor>* out_tensors) {
tensorflow::GraphDefBuilder b;
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
string input_name = "file_reader";
string output_name = "normalized";
tensorflow::Node* file_reader =
tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
b.opts().WithName(input_name));
auto file_reader = ReadFile(root.WithOpName(input_name), file_name);
// Now try to figure out what kind of file it is and decode it.
const int wanted_channels = 3;
tensorflow::Node* image_reader;
Output image_reader;
if (tensorflow::StringPiece(file_name).ends_with(".png")) {
image_reader = tensorflow::ops::DecodePng(
file_reader,
b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
DecodePng::Channels(wanted_channels));
} else {
// Assume if it's not a PNG then it must be a JPEG.
image_reader = tensorflow::ops::DecodeJpeg(
file_reader,
b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
DecodeJpeg::Channels(wanted_channels));
}
// Now cast the image data to float so we can do normal math on it.
tensorflow::Node* float_caster = tensorflow::ops::Cast(
image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
auto float_caster =
Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
// The convention for image ops in TensorFlow is that all images are expected
// to be in batches, so that they're four-dimensional arrays with indices of
// [batch, height, width, channel]. Because we only have a single image, we
// have to add a batch dimension of 1 to the start with ExpandDims().
tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
auto dims_expander = ExpandDims(root, float_caster, 0);
// Bilinearly resize the image to fit the required dimensions.
tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
dims_expander, tensorflow::ops::Const({input_height, input_width},
b.opts().WithName("size")),
b.opts());
auto resized = ResizeBilinear(
root, dims_expander,
Const(root.WithOpName("size"), {input_height, input_width}));
// Subtract the mean and divide by the scale.
tensorflow::ops::Div(
tensorflow::ops::Sub(
resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
tensorflow::ops::Const({input_std}, b.opts()),
b.opts().WithName(output_name));
Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
{input_std});
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensor.
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
@ -161,15 +155,16 @@ Status LoadGraph(string graph_file_name,
// their positions in the tensor, which correspond to categories.
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
Tensor* indices, Tensor* scores) {
tensorflow::GraphDefBuilder b;
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
string output_name = "top_k";
tensorflow::ops::TopKV2(tensorflow::ops::Const(outputs[0], b.opts()),
tensorflow::ops::Const(how_many_labels, b.opts()),
b.opts().WithName(output_name));
TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
// This runs the GraphDef network definition that we've just constructed, and
// returns the results in the output tensors.
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
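
The `GetTopLabels` hunk ends at session creation; a hedged, self-contained sketch of the full pattern follows, assuming the remainder fetches both `TopKV2` outputs by the `:0`/`:1` endpoint convention used elsewhere in this diff:

#include <memory>
#include <vector>
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/public/session.h"

// The input Tensor and the int k are used directly as operands.
tensorflow::Status TopK(const tensorflow::Tensor& scores_in, int k,
                        tensorflow::Tensor* indices,
                        tensorflow::Tensor* scores) {
  auto root = tensorflow::Scope::NewRootScope();
  tensorflow::ops::TopKV2(root.WithOpName("top_k"), scores_in, k);
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  std::vector<tensorflow::Tensor> out;
  TF_RETURN_IF_ERROR(session->Run({}, {"top_k:0", "top_k:1"}, {}, &out));
  *scores = out[0];   // top-k values
  *indices = out[1];  // top-k indices
  return tensorflow::Status::OK();
}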

@ -175,6 +175,11 @@ def tf_gen_op_wrappers_cc(name,
other_srcs=[],
other_hdrs=[],
pkg="",
deps=[
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:const_op",
],
op_gen="//tensorflow/cc:cc_op_gen_main"):
subsrcs = other_srcs
subhdrs = other_hdrs
@ -186,7 +191,12 @@ def tf_gen_op_wrappers_cc(name,
native.cc_library(name=name,
srcs=subsrcs,
hdrs=subhdrs,
deps=["//tensorflow/core:core_cpu"],
deps=deps + [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
copts=tf_copts(),
alwayslink=1,)

@ -39,16 +39,15 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
Tensor constant_tensor(DT_FLOAT, constant_shape);
test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* placeholder =
ops::Placeholder(DT_FLOAT, b.opts().WithAttr("shape", input_shape));
const string input_name = placeholder->name();
Node* constant = ops::Const(constant_tensor, b.opts());
const string output_name =
ops::MatMul(placeholder, constant, b.opts())->name();
auto root = Scope::NewRootScope().ExitOnError();
auto placeholder =
ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input_shape));
const string input_name = placeholder.node()->name();
auto m = ops::MatMul(root, placeholder, constant_tensor);
const string output_name = m.node()->name();
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
TF_ASSERT_OK(