Improvements to the C++ graph building API.
TESTED: - passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/ Change: 127585603
This commit is contained in:
parent
194efde518
commit
25ac3dabfa
RELEASE.md
tensorflow
cc
BUILD
framework
cc_op_gen.cccc_op_gen.hcc_op_gen_main.cccc_ops_test.ccops.ccops.hscope.ccscope.hscope_test.cctest_op.cc
ops
tutorials
contrib/util
core
examples/label_image
tensorflow.bzltools/benchmark
@ -4,6 +4,9 @@
|
||||
* Connectionist Temporal Classification ops are now "official" (see, e.g.,
|
||||
`tf.nn.ctc_loss`)
|
||||
* Preliminary graph-construction C API, for use by language bindings.
|
||||
* Major revision to the graph-construction C++ API. Scoping mechanism to make op
|
||||
naming, specifying control dependencies etc. more consistent. C++ values can
|
||||
be used directly as operands, making op construction more concise.
|
||||
|
||||
## Breaking Changes to the API
|
||||
* `env.h` replaces use of `New*File()` functions to use `std::unique_ptr`
|
||||
|
@ -2,29 +2,77 @@
|
||||
# TensorFlow is a computational framework, primarily for use in machine
|
||||
# learning applications.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrappers_cc")
|
||||
|
||||
cc_library(
|
||||
name = "cc_op_gen_main",
|
||||
srcs = [
|
||||
"ops/cc_op_gen.cc",
|
||||
"ops/cc_op_gen_main.cc",
|
||||
],
|
||||
hdrs = ["ops/cc_op_gen.h"],
|
||||
copts = tf_copts(),
|
||||
name = "ops",
|
||||
srcs = ["framework/ops.cc"],
|
||||
hdrs = ["framework/ops.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scope",
|
||||
srcs = ["framework/scope.cc"],
|
||||
hdrs = ["framework/scope.h"],
|
||||
deps = [
|
||||
":ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "framework/scope_test.cc",
|
||||
deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "const_op",
|
||||
srcs = ["ops/const_op.cc"],
|
||||
hdrs = ["ops/const_op.h"],
|
||||
deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
# Generates a library that contains C++ wrappers for ops.
|
||||
tf_cc_test(
|
||||
name = "ops/const_op_test.cc",
|
||||
deps = [
|
||||
":const_op",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
op_lib_names = [
|
||||
@ -41,7 +89,6 @@ tf_gen_op_wrappers_cc(
|
||||
"no_op",
|
||||
"parsing_ops",
|
||||
"random_ops",
|
||||
"sendrecv_ops",
|
||||
"sparse_ops",
|
||||
"state_ops",
|
||||
"string_ops",
|
||||
@ -52,12 +99,63 @@ tf_gen_op_wrappers_cc(
|
||||
"ops/const_op.h",
|
||||
"ops/standard_ops.h",
|
||||
],
|
||||
other_srcs = [
|
||||
"ops/const_op.cc",
|
||||
] + glob(["ops/*_grad.cc"]),
|
||||
pkg = "//tensorflow/core",
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "framework/cc_ops_test.cc",
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":test_op",
|
||||
":test_op_op_lib",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "sendrecv_ops",
|
||||
op_lib_names = [
|
||||
"sendrecv_ops",
|
||||
],
|
||||
pkg = "//tensorflow/core",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cc_op_gen_main",
|
||||
srcs = [
|
||||
"framework/cc_op_gen.cc",
|
||||
"framework/cc_op_gen.h",
|
||||
"framework/cc_op_gen_main.cc",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_op_op_lib",
|
||||
srcs = ["framework/test_op.cc"],
|
||||
linkstatic = 1,
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "test_op",
|
||||
op_lib_names = [
|
||||
"test_op",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "tutorials_example_trainer",
|
||||
srcs = ["tutorials/example_trainer.cc"],
|
||||
@ -69,6 +167,10 @@ cc_binary(
|
||||
deps = [
|
||||
":cc_ops",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
)
|
||||
|
798
tensorflow/cc/framework/cc_op_gen.cc
Normal file
798
tensorflow/cc/framework/cc_op_gen.cc
Normal file
@ -0,0 +1,798 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/framework/types.pb_text.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
const int kRightMargin = 79;
|
||||
|
||||
// Converts:
|
||||
// bazel-out/.../genfiles/(external/YYY/)?XX
|
||||
// to: XX.
|
||||
string GetPath(const std::string& dot_h_fname) {
|
||||
auto pos = dot_h_fname.find("/genfiles/");
|
||||
string result = dot_h_fname;
|
||||
if (pos != string::npos) {
|
||||
// - 1 account for the terminating null character (\0) in "/genfiles/".
|
||||
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
|
||||
}
|
||||
if (result.size() > sizeof("external/") &&
|
||||
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
|
||||
result = result.substr(sizeof("external/") - 1);
|
||||
pos = result.find("/");
|
||||
if (pos != string::npos) {
|
||||
result = result.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts:
|
||||
// cc/ops/gen_foo_ops.h
|
||||
// to:
|
||||
// CC_OPS_GEN_FOO_OPS_H_
|
||||
string ToGuard(const std::string& path) {
|
||||
string guard;
|
||||
guard.reserve(path.size() + 1); // + 1 -> trailing _
|
||||
for (const char c : path) {
|
||||
if (c >= 'A' && c <= 'Z') {
|
||||
guard += c;
|
||||
} else if (c >= 'a' && c <= 'z') {
|
||||
guard += c + 'A' - 'a';
|
||||
} else {
|
||||
guard += '_';
|
||||
}
|
||||
}
|
||||
guard += '_';
|
||||
return guard;
|
||||
}
|
||||
|
||||
// Change: Into:
|
||||
// ABC // ABC
|
||||
// //
|
||||
// DEF // DEF
|
||||
string MakeComment(StringPiece text, StringPiece indent) {
|
||||
string ret;
|
||||
while (!text.empty()) {
|
||||
int last_non_space = -1;
|
||||
int newline;
|
||||
for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
|
||||
if (text[newline] == '\n') break;
|
||||
if (text[newline] != ' ') last_non_space = newline;
|
||||
}
|
||||
if (last_non_space == -1) {
|
||||
strings::StrAppend(&ret, indent, "//\n");
|
||||
} else {
|
||||
strings::StrAppend(&ret, indent, "// ",
|
||||
text.substr(0, last_non_space + 1), "\n");
|
||||
}
|
||||
text.remove_prefix(newline + 1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
string PrintString(const string& str) {
|
||||
return strings::StrCat("\"", str_util::CEscape(str), "\"");
|
||||
}
|
||||
|
||||
string PrintTensorShape(const TensorShape& shape) {
|
||||
string ret = "{";
|
||||
for (int d = 0; d < shape.dims(); ++d) {
|
||||
if (d > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, shape.dim_size(d));
|
||||
}
|
||||
strings::StrAppend(&ret, "}");
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
string PrintArray(int64 num_elts, const T* array) {
|
||||
string ret;
|
||||
for (int64 i = 0; i < num_elts; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, array[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
string PrintTensor(const TensorProto& tensor_proto) {
|
||||
Tensor t(tensor_proto.dtype());
|
||||
CHECK(t.FromProto(tensor_proto));
|
||||
const int64 num_elts = t.NumElements();
|
||||
switch (t.dtype()) {
|
||||
case DT_FLOAT:
|
||||
return PrintArray(num_elts, t.flat<float>().data());
|
||||
case DT_DOUBLE:
|
||||
return PrintArray(num_elts, t.flat<double>().data());
|
||||
case DT_INT32:
|
||||
return PrintArray(num_elts, t.flat<int32>().data());
|
||||
case DT_UINT8:
|
||||
case DT_QUINT8:
|
||||
return PrintArray(num_elts, t.flat<uint8>().data());
|
||||
case DT_UINT16:
|
||||
case DT_QUINT16:
|
||||
return PrintArray(num_elts, t.flat<uint16>().data());
|
||||
case DT_INT16:
|
||||
case DT_QINT16:
|
||||
return PrintArray(num_elts, t.flat<int16>().data());
|
||||
case DT_INT8:
|
||||
case DT_QINT8:
|
||||
return PrintArray(num_elts, t.flat<int8>().data());
|
||||
case DT_INT64:
|
||||
return PrintArray(num_elts, t.flat<int64>().data());
|
||||
case DT_BOOL:
|
||||
return PrintArray(num_elts, t.flat<bool>().data());
|
||||
case DT_STRING: {
|
||||
string ret;
|
||||
for (int64 i = 0; i < num_elts; ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, " ");
|
||||
strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
default: {
|
||||
LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype());
|
||||
return string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string PrintAttrValue(string op, const AttrValue& attr_value) {
|
||||
switch (attr_value.value_case()) {
|
||||
case AttrValue::kS:
|
||||
return PrintString(attr_value.s());
|
||||
case AttrValue::kI:
|
||||
return strings::StrCat(attr_value.i());
|
||||
case AttrValue::kF:
|
||||
return strings::StrCat(attr_value.f());
|
||||
case AttrValue::kB:
|
||||
return attr_value.b() ? "true" : "false";
|
||||
case AttrValue::kType:
|
||||
return EnumName_DataType(attr_value.type());
|
||||
case AttrValue::kShape:
|
||||
return PrintTensorShape(TensorShape(attr_value.shape()));
|
||||
case AttrValue::kTensor:
|
||||
return strings::StrCat(
|
||||
"Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ",
|
||||
PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())),
|
||||
").AsTensorProto()");
|
||||
case AttrValue::kList: {
|
||||
string ret = "{";
|
||||
if (attr_value.list().s_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().s_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, PrintString(attr_value.list().s(i)));
|
||||
}
|
||||
} else if (attr_value.list().i_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().i_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, attr_value.list().i(i));
|
||||
}
|
||||
} else if (attr_value.list().f_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().f_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, attr_value.list().f(i));
|
||||
}
|
||||
} else if (attr_value.list().b_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().b_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
|
||||
}
|
||||
} else if (attr_value.list().type_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().type_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(&ret,
|
||||
EnumName_DataType(attr_value.list().type(i)));
|
||||
}
|
||||
} else if (attr_value.list().shape_size() > 0) {
|
||||
for (int i = 0; i < attr_value.list().shape_size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&ret, ", ");
|
||||
strings::StrAppend(
|
||||
&ret, PrintTensorShape(TensorShape(attr_value.list().shape(i))));
|
||||
}
|
||||
}
|
||||
strings::StrAppend(&ret, "}");
|
||||
return ret;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported Attr type: " << op << " "
|
||||
<< attr_value.value_case();
|
||||
}
|
||||
return "<Unknown AttrValue type>"; // Prevent missing return warning
|
||||
}
|
||||
|
||||
string ToCamelCase(const string& str) {
|
||||
string result;
|
||||
const char joiner = '_';
|
||||
int i = 0;
|
||||
bool cap = true;
|
||||
while (i < str.size()) {
|
||||
const char c = str[i++];
|
||||
if (c == joiner) {
|
||||
cap = true;
|
||||
} else if (cap) {
|
||||
result += toupper(c);
|
||||
cap = false;
|
||||
} else {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a <string, bool> pair. The string is the C++ type name to be used for
|
||||
// attr_type when defining an object of that type. The bool is a flag to
|
||||
// indicate whether to treat the type as const when accepting the C++ type as an
|
||||
// argument to a function.
|
||||
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
|
||||
static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
|
||||
StringPiece::Hasher>
|
||||
attr_type_map{
|
||||
{"string", {"StringPiece", false}},
|
||||
{"list(string)", {"gtl::ArraySlice<string>", true}},
|
||||
{"int", {"int64", false}},
|
||||
{"list(int)", {"gtl::ArraySlice<int>", true}},
|
||||
{"float", {"float", false}},
|
||||
{"list(float)", {"gtl::ArraySlice<float>", true}},
|
||||
{"bool", {"bool", false}},
|
||||
{"list(bool)", {"gtl::ArraySlice<bool>", true}},
|
||||
{"type", {"DataType", false}},
|
||||
{"list(type)", {"DataTypeSlice", true}},
|
||||
{"shape", {"TensorShape", false}},
|
||||
{"list(shape)", {"gtl::ArraySlice<TensorShape>", true}},
|
||||
{"tensor", {"TensorProto", true}},
|
||||
{"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
|
||||
{"func", {"NameAttrList", true}},
|
||||
};
|
||||
|
||||
auto entry = attr_type_map.find(attr_type);
|
||||
if (entry == attr_type_map.end()) {
|
||||
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
|
||||
return {"", false};
|
||||
}
|
||||
return entry->second;
|
||||
}
|
||||
|
||||
bool IsCPPKeyword(StringPiece name) {
|
||||
static const std::unordered_set<StringPiece, StringPiece::Hasher>
|
||||
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
|
||||
kCPPReserved{
|
||||
"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel",
|
||||
"atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool",
|
||||
"break", "case", "catch", "char", "char16_t", "char32_t", "class",
|
||||
"compl", "concept", "const", "const_cast", "constexpr", "continue",
|
||||
"decltype", "default", "delete", "do", "double", "dynamic_cast",
|
||||
"else", "enum", "explicit", "export", "extern", "false", "final",
|
||||
"float", "for", "friend", "goto", "if", "import", "inline", "int",
|
||||
"long", "module", "mutable", "namespace", "new", "noexcept", "not",
|
||||
"not_eq", "nullptr", "operator", "or", "or_eq", "override", "private",
|
||||
"protected", "public", "register", "reinterpret_cast", "requires",
|
||||
"return", "short", "signed", "sizeof", "static", "static_assert",
|
||||
"static_cast", "struct", "switch", "synchronized", "template", "this",
|
||||
"thread_local", "throw", "true", "try", "typedef", "typeid",
|
||||
"typename", "union", "unsigned", "using", "virtual", "void",
|
||||
"volatile", "wchar_t", "while", "xor", "xor_eq",
|
||||
|
||||
// The following are not C++ keywords, but names of local variables
|
||||
// and parameters used in the op constructor. Treating them as
|
||||
// keywords, so that other parameter names don't conflict with these.
|
||||
"builder", "node", "ret", "scope", "unique_name",
|
||||
};
|
||||
return kCPPReserved.count(name) > 0;
|
||||
}
|
||||
|
||||
string AvoidCPPKeywords(StringPiece name) {
|
||||
if (IsCPPKeyword(name)) {
|
||||
return strings::StrCat(name, "_");
|
||||
}
|
||||
return name.ToString();
|
||||
}
|
||||
|
||||
void InferArgAttributes(const OpDef::ArgDef& arg,
|
||||
std::unordered_map<string, string>* inferred_attrs) {
|
||||
if (!arg.type_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name());
|
||||
} else if (!arg.type_list_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name());
|
||||
}
|
||||
if (!arg.number_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name());
|
||||
}
|
||||
}
|
||||
|
||||
void InferOpAttributes(
|
||||
const OpDef& op_def,
|
||||
std::unordered_map<string, string>* inferred_input_attrs) {
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
InferArgAttributes(arg, inferred_input_attrs);
|
||||
}
|
||||
}
|
||||
|
||||
bool ArgIsList(const OpDef::ArgDef& arg) {
|
||||
return !arg.type_list_attr().empty() || !arg.number_attr().empty();
|
||||
}
|
||||
|
||||
bool HasOptionalAttrs(
|
||||
const OpDef& op_def,
|
||||
const std::unordered_map<string, string>& inferred_input_attrs) {
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
if ((inferred_input_attrs.find(attr.name()) ==
|
||||
inferred_input_attrs.end()) &&
|
||||
attr.has_default_value()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
struct OpInfo {
|
||||
explicit OpInfo(const OpDef& op_def);
|
||||
string GetOpAttrStruct() const;
|
||||
string GetConstructorDecl(StringPiece op_name_prefix,
|
||||
bool include_attr) const;
|
||||
void WriteClassDecl(WritableFile* h) const;
|
||||
void GetOutput(string* out) const;
|
||||
string GetConstructorBody() const;
|
||||
void WriteClassDef(WritableFile* cc) const;
|
||||
|
||||
string op_name;
|
||||
std::vector<string> arg_types;
|
||||
std::vector<string> arg_names;
|
||||
std::vector<string> output_types;
|
||||
std::vector<string> output_names;
|
||||
std::vector<bool> is_list_output;
|
||||
bool has_optional_attrs;
|
||||
string comment;
|
||||
|
||||
const OpDef& op_def;
|
||||
std::unordered_map<string, string> inferred_input_attrs;
|
||||
};
|
||||
|
||||
OpInfo::OpInfo(const OpDef& op_def) : op_def(op_def) {
|
||||
op_name = op_def.name();
|
||||
InferOpAttributes(op_def, &inferred_input_attrs);
|
||||
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
|
||||
arg_types.push_back("const ::tensorflow::Scope&");
|
||||
arg_names.push_back("scope");
|
||||
|
||||
if (op_def.summary().empty()) {
|
||||
comment = "TODO: add doc.\n";
|
||||
} else {
|
||||
comment = strings::StrCat(op_def.summary(), "\n");
|
||||
if (op_def.has_deprecation()) {
|
||||
strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ",
|
||||
op_def.deprecation().version(), ":\n",
|
||||
op_def.deprecation().explanation(), ".\n");
|
||||
}
|
||||
if (!op_def.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
|
||||
}
|
||||
}
|
||||
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
|
||||
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
arg_types.push_back(strings::StrCat(
|
||||
"::tensorflow::ops::", ArgIsList(arg) ? "InputList" : "Input"));
|
||||
arg_names.push_back(AvoidCPPKeywords(arg.name()));
|
||||
|
||||
// TODO(keveman): Include input type information.
|
||||
StringPiece description = arg.description();
|
||||
if (!description.empty()) {
|
||||
ConsumeEquals(&description);
|
||||
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
|
||||
arg.description(), "\n");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
// If the attr is going to be inferred or is optional, don't add it as a
|
||||
// required argument.
|
||||
if ((inferred_input_attrs.find(attr.name()) !=
|
||||
inferred_input_attrs.end()) ||
|
||||
attr.has_default_value()) {
|
||||
continue;
|
||||
}
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
|
||||
arg_types.push_back(strings::StrCat(use_const ? "const " : "",
|
||||
attr_type_name, use_const ? "&" : ""));
|
||||
arg_names.push_back(AvoidCPPKeywords(attr.name()));
|
||||
if (!attr.description().empty()) {
|
||||
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(attr.name()), ":\n");
|
||||
// TODO(keveman): Word wrap and indent this, to handle multi-line
|
||||
// descriptions.
|
||||
strings::StrAppend(&comment, " ", attr.description(), "\n");
|
||||
}
|
||||
}
|
||||
comment = MakeComment(comment, "");
|
||||
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
const auto& arg = op_def.output_arg(i);
|
||||
bool is_list = ArgIsList(arg);
|
||||
output_types.push_back(strings::StrCat("::tensorflow::ops::",
|
||||
is_list ? "OutputList" : "Output"));
|
||||
output_names.push_back(AvoidCPPKeywords(arg.name()));
|
||||
is_list_output.push_back(is_list);
|
||||
}
|
||||
}
|
||||
|
||||
string OpInfo::GetOpAttrStruct() const {
|
||||
string struct_fields;
|
||||
string setters;
|
||||
string attrs_comment = strings::StrCat("Optional attribute setters for ",
|
||||
op_def.name(), " :\n\n");
|
||||
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
// If attr will be inferred or it doesn't have a default value, don't
|
||||
// add it to the struct.
|
||||
if ((inferred_input_attrs.find(attr.name()) !=
|
||||
inferred_input_attrs.end()) ||
|
||||
!attr.has_default_value()) {
|
||||
continue;
|
||||
}
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
const string camel_case_name = ToCamelCase(attr.name());
|
||||
const string suffix =
|
||||
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
|
||||
const string attr_func_def =
|
||||
strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
|
||||
attr_type_name, use_const ? "&" : "");
|
||||
|
||||
strings::StrAppend(&attrs_comment, attr_func_def, "): Defaults to ",
|
||||
SummarizeAttrValue(attr.default_value()), "\n");
|
||||
if (!attr.description().empty()) {
|
||||
// TODO(keveman): Word wrap and indent this to handle multi-line
|
||||
// description.
|
||||
strings::StrAppend(&attrs_comment, " ", attr.description(), "\n");
|
||||
}
|
||||
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
|
||||
strings::StrAppend(&setters, " Attrs ret = *this;\n");
|
||||
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
|
||||
strings::StrAppend(&setters, " return ret;\n }\n\n");
|
||||
|
||||
strings::StrAppend(
|
||||
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
|
||||
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
|
||||
}
|
||||
|
||||
if (struct_fields.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
string struct_decl = MakeComment(attrs_comment, " ");
|
||||
strings::StrAppend(&struct_decl, " struct Attrs {\n");
|
||||
strings::StrAppend(&struct_decl, setters, struct_fields);
|
||||
strings::StrAppend(&struct_decl, " };\n");
|
||||
|
||||
return struct_decl;
|
||||
}
|
||||
|
||||
string OpInfo::GetConstructorDecl(StringPiece op_name_prefix,
|
||||
bool include_attr) const {
|
||||
const string prefix = strings::StrCat(op_name_prefix, op_name, "(");
|
||||
string c_decl;
|
||||
for (int i = 0; i < arg_types.size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&c_decl, ", ");
|
||||
strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]);
|
||||
}
|
||||
if (include_attr && has_optional_attrs) {
|
||||
strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs");
|
||||
}
|
||||
strings::StrAppend(&c_decl, ")");
|
||||
return WordWrap(prefix, c_decl, kRightMargin);
|
||||
}
|
||||
|
||||
void OpInfo::WriteClassDecl(WritableFile* h) const {
|
||||
string class_decl = comment;
|
||||
strings::StrAppend(&class_decl, "class ", op_name, " {\n");
|
||||
strings::StrAppend(&class_decl, " public:\n");
|
||||
if (has_optional_attrs) {
|
||||
strings::StrAppend(&class_decl, GetOpAttrStruct());
|
||||
}
|
||||
strings::StrAppend(&class_decl, " ",
|
||||
GetConstructorDecl("", /* include_attr */ false), ";\n");
|
||||
if (has_optional_attrs) {
|
||||
strings::StrAppend(&class_decl, " ",
|
||||
GetConstructorDecl("", /* include_attr */ true), ";\n");
|
||||
}
|
||||
if (output_types.empty()) {
|
||||
// Allow casting this class to Operation.
|
||||
strings::StrAppend(&class_decl,
|
||||
" operator ::tensorflow::ops::Operation() const { "
|
||||
"return operation; }\n");
|
||||
} else if (output_types.size() == 1) {
|
||||
if (is_list_output[0]) {
|
||||
// Write the subscript operator, allowing out[i] for the list-typed
|
||||
// output.
|
||||
strings::StrAppend(&class_decl,
|
||||
" ::tensorflow::ops::Output operator[](size_t index) "
|
||||
"const { return ",
|
||||
output_names[0], "[index]; }\n\n");
|
||||
|
||||
} else {
|
||||
// Write type cast functions, allowing casting this class to Input and
|
||||
// Output.
|
||||
strings::StrAppend(
|
||||
&class_decl, " operator ::tensorflow::ops::Output() const { return ",
|
||||
output_names[0], "; }\n");
|
||||
strings::StrAppend(
|
||||
&class_decl, " operator ::tensorflow::ops::Input() const { return ",
|
||||
output_names[0], "; }\n");
|
||||
// Write node() to get the Node* directly.
|
||||
strings::StrAppend(&class_decl,
|
||||
" ::tensorflow::Node* node() const { return ",
|
||||
output_names[0], ".node(); }\n");
|
||||
}
|
||||
}
|
||||
// Add the static functions to set optional attrs
|
||||
if (has_optional_attrs) {
|
||||
strings::StrAppend(&class_decl, "\n");
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
if ((inferred_input_attrs.find(attr.name()) !=
|
||||
inferred_input_attrs.end()) ||
|
||||
!attr.has_default_value()) {
|
||||
continue;
|
||||
}
|
||||
const auto entry = AttrTypeName(attr.type());
|
||||
const auto attr_type_name = entry.first;
|
||||
const bool use_const = entry.second;
|
||||
const string camel_case_name = ToCamelCase(attr.name());
|
||||
const string suffix =
|
||||
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
|
||||
const string attr_func_def = strings::StrCat(
|
||||
camel_case_name, suffix, "(", use_const ? "const " : "",
|
||||
attr_type_name, use_const ? "&" : "");
|
||||
strings::StrAppend(&class_decl, " static Attrs ", attr_func_def,
|
||||
" x) {\n");
|
||||
strings::StrAppend(&class_decl, " return Attrs().", camel_case_name,
|
||||
suffix, "(x);\n");
|
||||
strings::StrAppend(&class_decl, " }\n");
|
||||
}
|
||||
}
|
||||
|
||||
strings::StrAppend(&class_decl, "\n");
|
||||
|
||||
if (output_types.empty()) {
|
||||
strings::StrAppend(&class_decl, " Operation operation;\n");
|
||||
}
|
||||
for (int i = 0; i < output_types.size(); ++i) {
|
||||
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
|
||||
";\n");
|
||||
}
|
||||
|
||||
strings::StrAppend(&class_decl, "};\n\n");
|
||||
TF_CHECK_OK(h->Append(class_decl));
|
||||
}
|
||||
|
||||
void OpInfo::GetOutput(string* out) const {
|
||||
const string scope_str = arg_names[0];
|
||||
string return_on_error =
|
||||
strings::StrCat("if (!", scope_str, ".ok()) return;");
|
||||
|
||||
// No outputs.
|
||||
if (op_def.output_arg_size() == 0) {
|
||||
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
|
||||
return;
|
||||
}
|
||||
if (op_def.output_arg_size() == 1) {
|
||||
// One output, no need for NameRangeMap
|
||||
if (is_list_output[0]) {
|
||||
strings::StrAppend(out,
|
||||
" for (int64 i = 0; i < ret->num_outputs(); ++i)\n");
|
||||
strings::StrAppend(out, " this->", output_names[0],
|
||||
".push_back(Output(ret, i));\n");
|
||||
} else {
|
||||
strings::StrAppend(out, " this->", output_names[0],
|
||||
" = Output(ret, 0);\n");
|
||||
}
|
||||
return;
|
||||
}
|
||||
strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
|
||||
strings::StrAppend(
|
||||
out,
|
||||
" ::tensorflow::Status _status_ = "
|
||||
"::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), "
|
||||
"nullptr, &_outputs_range);\n");
|
||||
strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
|
||||
".UpdateStatus(_status_);\n", " return;\n");
|
||||
strings::StrAppend(out, " }\n\n");
|
||||
|
||||
for (int i = 0; i < op_def.output_arg_size(); ++i) {
|
||||
const string arg_range = strings::StrCat(
|
||||
"_outputs_range[\"", op_def.output_arg(i).name(), "\"]");
|
||||
if (is_list_output[i]) {
|
||||
strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ",
|
||||
arg_range, ".second; ++i)\n");
|
||||
strings::StrAppend(out, " this->", output_names[i],
|
||||
".push_back(Output(ret, i));\n");
|
||||
} else {
|
||||
strings::StrAppend(out, " this->", output_names[i], " = Output(ret, ",
|
||||
arg_range, ".first);\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string OpInfo::GetConstructorBody() const {
|
||||
const string scope_str = arg_names[0];
|
||||
|
||||
string body;
|
||||
string return_on_error =
|
||||
strings::StrCat("if (!", scope_str, ".ok()) return;");
|
||||
|
||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
|
||||
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
|
||||
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
|
||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||
}
|
||||
|
||||
strings::StrAppend(&body, " ::tensorflow::Node* ret;\n");
|
||||
strings::StrAppend(&body, " const auto unique_name = ", scope_str,
|
||||
".GetUniqueNameForOp(\"", op_def.name(), "\");\n");
|
||||
strings::StrAppend(
|
||||
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
|
||||
op_def.name(), "\")\n");
|
||||
const string spaces = " ";
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
|
||||
}
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
|
||||
continue;
|
||||
}
|
||||
const string attr_name = attr.has_default_value()
|
||||
? strings::StrCat("attrs.", attr.name(), "_")
|
||||
: AvoidCPPKeywords(attr.name());
|
||||
strings::StrAppend(&body, spaces, ".Attr(\"", attr.name(), "\", ",
|
||||
attr_name, ")\n");
|
||||
}
|
||||
strings::StrAppend(&body, " ;\n");
|
||||
strings::StrAppend(&body, " ", scope_str, ".UpdateBuilder(&builder);\n");
|
||||
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
|
||||
scope_str, ".graph(), &ret));\n");
|
||||
GetOutput(&body);
|
||||
return body;
|
||||
}
|
||||
|
||||
void OpInfo::WriteClassDef(WritableFile* cc) const {
|
||||
string class_def;
|
||||
strings::StrAppend(&class_def,
|
||||
GetConstructorDecl(strings::StrCat(op_name, "::"),
|
||||
/* include_attr */ true),
|
||||
" {\n");
|
||||
strings::StrAppend(&class_def, GetConstructorBody());
|
||||
strings::StrAppend(&class_def, "}\n\n");
|
||||
|
||||
if (has_optional_attrs) {
|
||||
strings::StrAppend(&class_def,
|
||||
GetConstructorDecl(strings::StrCat(op_name, "::"),
|
||||
/* include_attr */ false));
|
||||
strings::StrAppend(&class_def, "\n : ", op_name, "(");
|
||||
int i = 0;
|
||||
for (; i < arg_names.size(); ++i) {
|
||||
if (i > 0) strings::StrAppend(&class_def, ", ");
|
||||
strings::StrAppend(&class_def, arg_names[i]);
|
||||
}
|
||||
if (i > 0) strings::StrAppend(&class_def, ", ");
|
||||
strings::StrAppend(&class_def, op_name, "::Attrs()");
|
||||
strings::StrAppend(&class_def, ") {}\n\n");
|
||||
}
|
||||
TF_CHECK_OK(cc->Append(class_def));
|
||||
}
|
||||
|
||||
void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
|
||||
OpInfo op_info(op_def);
|
||||
|
||||
op_info.WriteClassDecl(h);
|
||||
op_info.WriteClassDef(cc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
|
||||
const std::string& dot_cc_fname) {
|
||||
Env* env = Env::Default();
|
||||
std::unique_ptr<WritableFile> h = nullptr;
|
||||
std::unique_ptr<WritableFile> cc = nullptr;
|
||||
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
|
||||
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
|
||||
|
||||
const string header =
|
||||
R"header(// This file is MACHINE GENERATED! Do not edit.
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
)header";
|
||||
|
||||
// TODO(keveman): Make namespaces configurable.
|
||||
const string namespace_begin = R"namespace(
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
)namespace";
|
||||
|
||||
const string footer = R"footer(} // namespace ops
|
||||
} // namespace tensorflow
|
||||
)footer";
|
||||
|
||||
const string op_header = GetPath(dot_h_fname);
|
||||
const string op_header_guard = ToGuard(op_header);
|
||||
const string cc_header = strings::StrCat(
|
||||
R"include(// This file is MACHINE GENERATED! Do not edit.
|
||||
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
)include",
|
||||
"#include \"", op_header, "\"\n", namespace_begin);
|
||||
|
||||
TF_CHECK_OK(h->Append(
|
||||
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
|
||||
"#ifndef ",
|
||||
op_header_guard,
|
||||
"\n"
|
||||
"#define ",
|
||||
op_header_guard, "\n\n")));
|
||||
TF_CHECK_OK(h->Append(header));
|
||||
TF_CHECK_OK(h->Append(namespace_begin));
|
||||
TF_CHECK_OK(cc->Append(cc_header));
|
||||
|
||||
for (const auto& op_def : ops.op()) {
|
||||
WriteCCOp(op_def, h.get(), cc.get());
|
||||
}
|
||||
|
||||
TF_CHECK_OK(h->Append(footer));
|
||||
TF_CHECK_OK(
|
||||
h->Append(strings::StrCat("\n#endif ", "// ", op_header_guard, "\n")));
|
||||
TF_CHECK_OK(cc->Append(footer));
|
||||
|
||||
TF_CHECK_OK(cc->Close());
|
||||
TF_CHECK_OK(h->Close());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_OPS_CC_OP_GEN_H_
|
||||
#define TENSORFLOW_CC_OPS_CC_OP_GEN_H_
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
|
||||
@ -26,4 +26,4 @@ void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_OPS_CC_OP_GEN_H_
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/cc_op_gen.h"
|
||||
#include "tensorflow/cc/framework/cc_op_gen.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
229
tensorflow/cc/framework/cc_ops_test.cc
Normal file
229
tensorflow/cc/framework/cc_ops_test.cc
Normal file
@ -0,0 +1,229 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/cc/ops/test_op.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
namespace {
|
||||
|
||||
Output Linear(const Scope& scope, Input x, Input w, Input b) {
|
||||
auto cop_scopes = scope.GetCompositeOpScopes("linear");
|
||||
auto m = MatMul(cop_scopes.child, x, w);
|
||||
return BiasAdd(cop_scopes.last, m, b);
|
||||
}
|
||||
|
||||
void GetTensors(const Scope& scope, OutputList tensors,
|
||||
std::vector<Tensor>* out) {
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
GraphDef def;
|
||||
scope.graph()->ToGraphDef(&def);
|
||||
|
||||
graph::SetDefaultDevice("/cpu:0", &def);
|
||||
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
std::vector<string> names;
|
||||
for (const auto& t : tensors) {
|
||||
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
|
||||
}
|
||||
TF_CHECK_OK(session->Run({}, names, {}, out));
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
|
||||
std::vector<Tensor> outputs;
|
||||
GetTensors(scope, {tensor}, &outputs);
|
||||
*out = outputs[0];
|
||||
}
|
||||
|
||||
void GetColocationConstraints(Output tensor, std::vector<string>* constraints) {
|
||||
constraints->clear();
|
||||
const auto& attrs = tensor.op().node()->def().attr();
|
||||
ASSERT_TRUE(attrs.find("_class") != attrs.end());
|
||||
auto loc = attrs.find("_class")->second;
|
||||
TF_EXPECT_OK(AttrValueHasType(loc, "list(string)"));
|
||||
if (loc.value_case() == AttrValue::kList && loc.list().s_size() > 0) {
|
||||
for (int i = 0; i < loc.list().s_size(); ++i) {
|
||||
if (loc.list().s(i).find("loc:@") == 0) {
|
||||
constraints->push_back(loc.list().s(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(CCOpTest, Basic) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = Const(root, {{1, 1}});
|
||||
// NOTE: The recommended style for constructing ops is
|
||||
// auto v = OpConstructor(t0, t1, ..);
|
||||
// Since the wrappers are implemented as one class per op, the following
|
||||
// style is also possible :
|
||||
// PrimitiveOp p(t0, t1, ...);
|
||||
// It's being used here ONLY to ensure that, that style is tested.
|
||||
MatMul m(root, c, {{41}, {1}});
|
||||
TF_EXPECT_OK(root.status());
|
||||
Tensor out;
|
||||
GetTensor(root, m, &out);
|
||||
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({42}, {1, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, Attrs) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto m = MatMul(root, {{1}, {1}}, {{41}, {1}}, MatMul::TransposeA(true));
|
||||
TF_EXPECT_OK(root.status());
|
||||
Tensor out;
|
||||
GetTensor(root, m, &out);
|
||||
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({42}, {1, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, SplitConcat) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Split p(root, 0, {{1}, {2}}, 2);
|
||||
auto c = Concat(root, 0, {p[0], p[1]});
|
||||
TF_EXPECT_OK(root.status());
|
||||
Tensor out;
|
||||
GetTensor(root, c, &out);
|
||||
test::ExpectTensorEqual<int>(out, test::AsTensor<int>({1, 2}, {2, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, CompositeOp) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto l = Linear(root.WithOpName("layer0"), {{10.0f, -3.0f}},
|
||||
{{.8f, .5f}, {.1f, .6f}}, {-8.0f, 31.0f});
|
||||
TF_EXPECT_OK(root.status());
|
||||
EXPECT_EQ(l.node()->name(), "layer0");
|
||||
Tensor out;
|
||||
GetTensor(root, l, &out);
|
||||
test::ExpectClose(out, test::AsTensor<float>({-0.3, 34.2}, {1, 2}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, MultiOutput) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto u = Unique(root, {1, 2, 2, 4, 3, 2});
|
||||
std::vector<Tensor> outputs;
|
||||
GetTensors(root, {u.y, u.idx}, &outputs);
|
||||
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({1, 2, 4, 3}));
|
||||
test::ExpectTensorEqual<int>(outputs[1],
|
||||
test::AsTensor<int>({0, 1, 1, 2, 3, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, ExampleTrainer) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
// a = [3 2; -1 0]
|
||||
auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
|
||||
// x = [1.0; 1.0]
|
||||
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
|
||||
// y = a * x
|
||||
auto y = MatMul(root.WithOpName("y"), a, x);
|
||||
// y2 = y.^2
|
||||
auto y2 = Square(root, y);
|
||||
// y2_sum = sum(y2)
|
||||
auto y2_sum = Sum(root, y2, 0);
|
||||
// y_norm = sqrt(y2_sum)
|
||||
auto y_norm = Sqrt(root, y2_sum);
|
||||
// y_normalized = y ./ y_norm
|
||||
auto y_normalized = Div(root.WithOpName("y_normalized"), y, y_norm);
|
||||
Tensor out;
|
||||
GetTensor(root, y_normalized, &out);
|
||||
test::ExpectTensorNear<float>(
|
||||
out, test::AsTensor<float>({0.98058069, -0.19611613}, {2, 1}), 1e-5);
|
||||
}
|
||||
|
||||
TEST(CCOpTest, ThrowAwayOp) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
ThrowAway1(root, 1, 2.3f, 1, 1, 1, ThrowAway1::Builder(42));
|
||||
ThrowAway2(root, ThrowAway2::ThrowAway2_(3).Scope(1));
|
||||
TF_EXPECT_OK(root.status());
|
||||
}
|
||||
|
||||
TEST(CCOpTest, ControlDeps) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto v = Variable(root, {}, DT_FLOAT);
|
||||
auto assign = Assign(root, v, 41.0f);
|
||||
Scope with_control_deps = root.WithControlDependencies(assign);
|
||||
auto add = Add(with_control_deps, v, 1.0f);
|
||||
Scope no_control_deps = with_control_deps.WithNoControlDependencies();
|
||||
auto sub = Sub(no_control_deps, 3.0f, 2.0f);
|
||||
auto is_inited =
|
||||
IsVariableInitialized(no_control_deps.WithControlDependencies(sub), v);
|
||||
|
||||
TF_EXPECT_OK(root.status());
|
||||
|
||||
std::vector<Tensor> out;
|
||||
|
||||
GetTensors(root, {add}, &out);
|
||||
test::ExpectTensorNear<float>(out[0], test::AsTensor<float>({42.0f}, {}),
|
||||
1e-5);
|
||||
|
||||
out.clear();
|
||||
// Note : GetTensors creates a new session, so 'v' is uninitialized.
|
||||
// sub should have no control deps, so it should not cause the assign to run.
|
||||
// Hence is_inited should be false.
|
||||
GetTensors(root, {sub, is_inited}, &out);
|
||||
test::ExpectTensorNear<float>(out[0], test::AsTensor<float>({1.0f}, {}),
|
||||
1e-5);
|
||||
test::ExpectTensorEqual<bool>(out[1], test::AsTensor<bool>({false}, {}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, KernelLabel) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f);
|
||||
TF_EXPECT_OK(root.status());
|
||||
const auto& attrs = add.z.op().node()->def().attr();
|
||||
ASSERT_TRUE(attrs.find("_kernel") != attrs.end());
|
||||
auto kernel_attr = attrs.find("_kernel")->second;
|
||||
TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string"));
|
||||
EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel");
|
||||
}
|
||||
|
||||
TEST(CCOpTest, ColocateWith) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c1 = Const(root.WithOpName("c1"), 1);
|
||||
auto c2 = Const(root.WithOpName("c2").ColocateWith(c1), 2);
|
||||
std::vector<string> constraints;
|
||||
GetColocationConstraints(c2, &constraints);
|
||||
EXPECT_EQ(constraints[0], "loc:@c1");
|
||||
|
||||
auto c3 = Const(root.WithOpName("c3").ColocateWith(c2), 3);
|
||||
GetColocationConstraints(c3, &constraints);
|
||||
EXPECT_EQ(constraints[0], "loc:@c1");
|
||||
|
||||
auto a = Const(root.WithOpName("a"), 4);
|
||||
auto c4 = Const(root.WithOpName("c4").ColocateWith(a), 5);
|
||||
GetColocationConstraints(c4, &constraints);
|
||||
EXPECT_EQ(constraints[0], "loc:@a");
|
||||
|
||||
auto c5 = Const(root.WithOpName("c5").ColocateWith(c3).ColocateWith(c4), 6);
|
||||
GetColocationConstraints(c5, &constraints);
|
||||
EXPECT_EQ(constraints[0], "loc:@a");
|
||||
EXPECT_EQ(constraints[1], "loc:@c1");
|
||||
|
||||
Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4);
|
||||
auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7);
|
||||
const auto& attrs = c6.op().node()->def().attr();
|
||||
EXPECT_TRUE(attrs.find("_class") == attrs.end());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
70
tensorflow/cc/framework/ops.cc
Normal file
70
tensorflow/cc/framework/ops.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
Input::Initializer::Initializer(
|
||||
const std::initializer_list<Input::Initializer>& v) {
|
||||
if (v.size() < 1) {
|
||||
// Empty initializer list defaults to float tensor with shape (0,)
|
||||
tensor = Tensor(DT_FLOAT, TensorShape{0});
|
||||
return;
|
||||
}
|
||||
auto const& first = *v.begin();
|
||||
// Check to make sure that the constituent Initializers are all the same
|
||||
// type and same shape.
|
||||
for (auto const& e : v) {
|
||||
if (e.tensor.dtype() != first.tensor.dtype()) {
|
||||
status = errors::InvalidArgument(
|
||||
"Initializer list components should all have the same type");
|
||||
return;
|
||||
}
|
||||
if (!TensorShape{e.tensor.shape()}.IsSameSize(
|
||||
TensorShape{first.tensor.shape()})) {
|
||||
status = errors::InvalidArgument(
|
||||
"Initializer list components should all have the same shape");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Form the new shape.
|
||||
TensorShape shape{static_cast<int64>(v.size())};
|
||||
shape.AppendShape(TensorShape{first.tensor.shape()});
|
||||
|
||||
Tensor t(first.tensor.dtype(), shape);
|
||||
|
||||
// Collate the constituent Tensors.
|
||||
size_t offset = 0;
|
||||
for (auto const& e : v) {
|
||||
Tensor elem = e.tensor;
|
||||
if (first.tensor.dtype() == DT_STRING) {
|
||||
for (int i = 0; i < elem.NumElements(); ++i) {
|
||||
t.flat<string>()(offset + i) = elem.flat<string>()(i);
|
||||
}
|
||||
offset += elem.NumElements();
|
||||
} else {
|
||||
std::copy_n(elem.tensor_data().data(), elem.TotalBytes(),
|
||||
const_cast<char*>(t.tensor_data().data()) + offset);
|
||||
offset += elem.TotalBytes();
|
||||
}
|
||||
}
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
261
tensorflow/cc/framework/ops.h
Normal file
261
tensorflow/cc/framework/ops.h
Normal file
@ -0,0 +1,261 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
// TBD(keveman): This is going to be moved to //third_party/tensorflow
|
||||
// eventually. Remove the NOLINT comment when moving.
|
||||
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
// Represents a node in the computation graph.
|
||||
class Operation {
|
||||
public:
|
||||
Operation() : node_(nullptr) {}
|
||||
explicit Operation(Node* n) : node_(n) {}
|
||||
|
||||
int num_outputs() const { return node_->num_outputs(); }
|
||||
DataType output_type(int o) const { return node_->output_type(o); }
|
||||
Node* node() const { return node_; }
|
||||
|
||||
private:
|
||||
Node* node_;
|
||||
};
|
||||
|
||||
// Represents a tensor value produced by an Operation.
|
||||
class Output {
|
||||
public:
|
||||
Output() = default;
|
||||
explicit Output(Node* n) : op_(n) {}
|
||||
Output(Node* n, int64 index) : op_(n), index_(index) {}
|
||||
Output(const Operation& op, int64 index) : op_(op), index_(index) {}
|
||||
|
||||
Operation op() const { return op_; }
|
||||
Node* node() const { return op().node(); }
|
||||
int64 index() const { return index_; }
|
||||
DataType type() const { return op_.output_type(index_); }
|
||||
|
||||
private:
|
||||
Operation op_ = Operation(nullptr);
|
||||
int64 index_ = 0;
|
||||
};
|
||||
|
||||
// Represents a tensor value that can be used as an operand to an Operation.
|
||||
class Input {
|
||||
public:
|
||||
// Initializer enables constructing an Input object from various kinds of C++
|
||||
// constants such as simple primitive constants and nested initializer lists
|
||||
// representing a multi-dimensional array. Initializer constructors are all
|
||||
// templates, so the aforementioned kinds of C++ constants can be used to
|
||||
// construct an Initializer. Intializer stores the value it got constructed
|
||||
// with in a Tensor object.
|
||||
struct Initializer {
|
||||
// Construct from a scalar value of an arithmetic type or a type that can be
|
||||
// converted to a string (eg. a string literal).
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_arithmetic<T>::value ||
|
||||
std::is_convertible<T, string>::value>::type>
|
||||
Initializer(const T& v) { // NOLINT(runtime/explicit)
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), TensorShape());
|
||||
t.flat<T>()(0) = RealT(v);
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
explicit Initializer(const Tensor& t) : tensor(t) {}
|
||||
|
||||
// Construct from a scalar value and an explicit shape
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_arithmetic<T>::value ||
|
||||
std::is_convertible<T, string>::value>::type>
|
||||
Initializer(const T& v, const TensorShape& shape) {
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), shape);
|
||||
for (int64 i = 0; i < t.NumElements(); ++i) {
|
||||
t.flat<T>()(i) = RealT(v);
|
||||
}
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
// Construct from a initializer list of scalars (a one-dimensional tensor).
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_arithmetic<T>::value ||
|
||||
std::is_convertible<T, string>::value>::type>
|
||||
Initializer(
|
||||
const std::initializer_list<T>& v) { // NOLINT(runtime/explicit)
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(),
|
||||
TensorShape{static_cast<int>(v.size())});
|
||||
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
// Construct from a initializer list of scalars and an explicit shape.
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_arithmetic<T>::value ||
|
||||
std::is_convertible<T, string>::value>::type>
|
||||
Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
|
||||
typedef typename RealType<T>::type RealT;
|
||||
Tensor t(DataTypeToEnum<RealT>::v(), shape);
|
||||
if (t.NumElements() != v.size()) {
|
||||
status = errors::InvalidArgument(
|
||||
"Cannot construct a tensor with ", t.NumElements(),
|
||||
" from an initializer list with ", v.size(), " elements");
|
||||
return;
|
||||
}
|
||||
std::copy_n(v.begin(), v.size(), t.flat<RealT>().data());
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
// Construct a multi-dimensional tensor from a nested initializer list. Note
|
||||
// that C++ syntax allows nesting of arbitrarily typed intializer lists, so
|
||||
// such invalid initializers cannot be disallowed at compile time. This
|
||||
// function performs checks to make sure that the nested initializer list is
|
||||
// indeed a valid multi-dimensional tensor.
|
||||
Initializer(const std::initializer_list<Initializer>& v);
|
||||
|
||||
template <typename T, bool = std::is_convertible<T, string>::value>
|
||||
struct RealType {
|
||||
typedef string type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct RealType<T, false> {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
TensorProto AsTensorProto() {
|
||||
TensorProto tensor_proto;
|
||||
if (tensor.NumElements() > 1) {
|
||||
tensor.AsProtoTensorContent(&tensor_proto);
|
||||
} else {
|
||||
tensor.AsProtoField(&tensor_proto);
|
||||
}
|
||||
return tensor_proto;
|
||||
}
|
||||
|
||||
Status status;
|
||||
Tensor tensor;
|
||||
};
|
||||
|
||||
// All of Input's constructors are implicit. Input can be implicitly
|
||||
// constructed from the following objects :
|
||||
// * Output: This is so that the output of an Operation can be directly used
|
||||
// as the input to a op wrapper, which takes Inputs.
|
||||
// * A scalar, or a multi-dimensional tensor specified as a recursive
|
||||
// initializer list. This enables directly passing constants as
|
||||
// inputs to op wrappers.
|
||||
Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit)
|
||||
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_arithmetic<T>::value ||
|
||||
std::is_convertible<T, string>::value>::type>
|
||||
Input(const T& v) // NOLINT(runtime/explicit)
|
||||
: Input(Initializer(v)) {}
|
||||
|
||||
Input(const Initializer& init) // NOLINT(runtime/explicit)
|
||||
: status_(init.status),
|
||||
tensor_(init.tensor) {}
|
||||
|
||||
Input(const Tensor& t) // NOLINT(runtime/explicit)
|
||||
: status_(Status::OK()),
|
||||
tensor_(t) {}
|
||||
|
||||
Input(const std::initializer_list<Initializer>&
|
||||
init) { // NOLINT(runtime/explicit)
|
||||
for (const auto& i : init) {
|
||||
if (!i.status.ok()) {
|
||||
status_ = i.status;
|
||||
return;
|
||||
}
|
||||
}
|
||||
tensor_ = Initializer(init).tensor;
|
||||
}
|
||||
|
||||
// Constructor specifying a node name, index and datatype. This should only be
|
||||
// used for specifying a backward edge, needed by control flow.
|
||||
Input(const string& name, int i, DataType dt)
|
||||
: node_name_(name), index_(i), data_type_(dt) {}
|
||||
|
||||
Node* node() const { return output_.node(); }
|
||||
string node_name() const { return node_name_; }
|
||||
int index() const { return node_name_.empty() ? output_.index() : index_; }
|
||||
DataType data_type() const { return data_type_; }
|
||||
Status status() const { return status_; }
|
||||
const Tensor& tensor() const { return tensor_; }
|
||||
|
||||
private:
|
||||
Status status_;
|
||||
Output output_ = Output(Operation(nullptr), 0);
|
||||
Tensor tensor_;
|
||||
const string node_name_ = "";
|
||||
int index_ = 0;
|
||||
DataType data_type_ = DT_INVALID;
|
||||
};
|
||||
|
||||
// A type for representing the output of ops that produce more than one output,
|
||||
// or a list of tensors.
|
||||
typedef std::vector<Output> OutputList;
|
||||
|
||||
// A type for representing the input to ops that require a list of tensors.
|
||||
class InputList {
|
||||
public:
|
||||
// Implicitly convert a list of outputs to a list of inputs. This is useful to
|
||||
// write code such as tf.Concat(tf.Split(x, 4)).
|
||||
InputList(const OutputList& out) { // NOLINT(runtime/explicit)
|
||||
for (auto const& x : out) {
|
||||
inputs_.push_back(x);
|
||||
}
|
||||
}
|
||||
|
||||
InputList(
|
||||
const std::initializer_list<Input>& inputs) // NOLINT(runtime/explicit)
|
||||
: inputs_(inputs.begin(), inputs.end()) {}
|
||||
|
||||
InputList(const tensorflow::gtl::ArraySlice<Input>&
|
||||
inputs) // NOLINT(runtime/explicit)
|
||||
: inputs_(inputs.begin(), inputs.end()) {}
|
||||
|
||||
InputList(
|
||||
const std::initializer_list<Output>& out) { // NOLINT(runtime/explicit)
|
||||
for (auto const& x : out) {
|
||||
inputs_.push_back(x);
|
||||
}
|
||||
}
|
||||
|
||||
typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
|
||||
typename std::vector<Input>::iterator end() { return inputs_.end(); }
|
||||
typename std::vector<Input>::const_iterator begin() const {
|
||||
return inputs_.begin();
|
||||
}
|
||||
typename std::vector<Input>::const_iterator end() const {
|
||||
return inputs_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Input> inputs_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_
|
347
tensorflow/cc/framework/scope.cc
Normal file
347
tensorflow/cc/framework/scope.cc
Normal file
@ -0,0 +1,347 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map)
|
||||
: graph_(graph),
|
||||
status_(status),
|
||||
name_map_(name_map),
|
||||
scope_used_(nullptr) {}
|
||||
|
||||
Scope Scope::NewRootScope() {
|
||||
return Scope(new Graph(OpRegistry::Global()), new Status, new Scope::NameMap);
|
||||
}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name,
|
||||
bool copy_names)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(copy_names ? other.name_map_
|
||||
: std::shared_ptr<NameMap>(new NameMap)),
|
||||
scope_used_(nullptr),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(name),
|
||||
op_name_(""),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::OpName, const string& name,
|
||||
const string& op_name)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(name),
|
||||
op_name_(op_name),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::ControlDeps,
|
||||
std::vector<ops::Operation> control_deps, bool clear_control_deps)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(clear_control_deps
|
||||
? std::vector<ops::Operation>()
|
||||
: (control_deps.insert(control_deps.begin(),
|
||||
other.control_deps_.begin(),
|
||||
other.control_deps_.end()),
|
||||
control_deps)),
|
||||
name_(other.name_),
|
||||
op_name_(other.op_name_),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::Device, const string& device)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(other.name_),
|
||||
op_name_(other.op_name_),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(device),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::SingleUseScope,
|
||||
const string& op_name)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(new bool(false)),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(other.name_),
|
||||
op_name_(op_name),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::ExitOnError)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(other.name_),
|
||||
op_name_(other.op_name_),
|
||||
exit_on_error_(true),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::KernelLabel,
|
||||
const string& kernel_label)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(other.name_),
|
||||
op_name_(other.op_name_),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(kernel_label),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(other.colocation_constraints_) {}
|
||||
|
||||
Scope::Scope(const Scope& other, Scope::Tags::Colocate,
|
||||
const ops::Operation& colocate_with_op, bool clear_colocations)
|
||||
: graph_(other.graph_),
|
||||
status_(other.status_),
|
||||
name_map_(other.name_map_),
|
||||
scope_used_(other.scope_used_),
|
||||
control_deps_(other.control_deps_),
|
||||
name_(other.name_),
|
||||
op_name_(other.op_name_),
|
||||
exit_on_error_(other.exit_on_error_),
|
||||
kernel_label_(other.kernel_label_),
|
||||
device_(other.device_),
|
||||
colocation_constraints_(
|
||||
clear_colocations
|
||||
? std::unordered_set<string>()
|
||||
: other.GetColocationConstraints(colocate_with_op)) {}
|
||||
|
||||
std::unordered_set<string> Scope::GetColocationConstraints(
|
||||
const ops::Operation& colocate_with_op) const {
|
||||
std::unordered_set<string> current_constraints(colocation_constraints_);
|
||||
const NodeDef& node_def = colocate_with_op.node()->def();
|
||||
if (node_def.attr().find("_class") != node_def.attr().end()) {
|
||||
const AttrValue& loc = node_def.attr().find("_class")->second;
|
||||
if (loc.value_case() == AttrValue::kList && loc.list().s_size() > 0) {
|
||||
for (int i = 0; i < loc.list().s_size(); ++i) {
|
||||
// Filter out the ones that don't have "loc:@" prefix
|
||||
if (loc.list().s(i).find("loc:@") == 0) {
|
||||
// Skip the "loc:@" prefix
|
||||
current_constraints.insert(loc.list().s(i).substr(5));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
current_constraints.insert(colocate_with_op.node()->name());
|
||||
}
|
||||
return current_constraints;
|
||||
}
|
||||
|
||||
void Scope::UpdateStatus(const Status s) const {
|
||||
status_->Update(s);
|
||||
if (exit_on_error_ && !status_->ok()) {
|
||||
LOG(FATAL) << status_;
|
||||
}
|
||||
}
|
||||
|
||||
Status Scope::ToGraphDef(GraphDef* gdef) const {
|
||||
if (!status_->ok()) {
|
||||
return *status_;
|
||||
}
|
||||
graph()->ToGraphDef(gdef);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Scope::ToGraph(Graph* g) const {
|
||||
if (status_->ok()) {
|
||||
GraphDef graph_def;
|
||||
graph()->ToGraphDef(&graph_def);
|
||||
GraphConstructorOptions opts;
|
||||
UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
|
||||
}
|
||||
return *status_;
|
||||
}
|
||||
|
||||
void Scope::UpdateBuilder(NodeBuilder* builder) const {
|
||||
std::vector<Node*> control_inputs;
|
||||
for (const auto& op : control_deps_) {
|
||||
control_inputs.push_back(op.node());
|
||||
}
|
||||
builder->ControlInputs(control_inputs);
|
||||
|
||||
if (!kernel_label_.empty()) {
|
||||
builder->Attr("_kernel", kernel_label_);
|
||||
}
|
||||
|
||||
if (!colocation_constraints_.empty()) {
|
||||
std::vector<string> constraints(colocation_constraints_.begin(),
|
||||
colocation_constraints_.end());
|
||||
// Sort the set.
|
||||
std::sort(constraints.begin(), constraints.end());
|
||||
// Add loc:@ prefix
|
||||
std::transform(constraints.begin(), constraints.end(), constraints.begin(),
|
||||
[](const string& s) { return strings::StrCat("loc:@", s); });
|
||||
builder->Attr("_class", constraints);
|
||||
}
|
||||
if (!device_.empty()) {
|
||||
builder->Device(device_);
|
||||
}
|
||||
}
|
||||
|
||||
string Scope::GetUniqueName(const string& prefix, bool check_single_use) const {
|
||||
if (check_single_use && single_use_scope()) {
|
||||
if (*scope_used_) {
|
||||
*status_ =
|
||||
errors::AlreadyExists(prefix, " already exists in the current scope");
|
||||
return "";
|
||||
}
|
||||
*scope_used_ = true;
|
||||
return prefix;
|
||||
}
|
||||
auto entry = name_map_->find(prefix);
|
||||
string unique_name = prefix;
|
||||
if (entry == name_map_->end()) {
|
||||
name_map_->insert({prefix, 0});
|
||||
} else {
|
||||
unique_name = strings::StrCat(unique_name, "_", ++entry->second);
|
||||
}
|
||||
return unique_name;
|
||||
}
|
||||
|
||||
string Scope::GetNameForOp(const string& default_name) const {
|
||||
const string unique_name =
|
||||
GetUniqueName(default_name, true /* check_single_use */);
|
||||
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
|
||||
return strings::StrCat(name_, sep, unique_name);
|
||||
}
|
||||
|
||||
string Scope::GetUniqueNameForOp(const string& default_name) const {
|
||||
if (single_use_scope()) {
|
||||
if (op_name_.empty() || *scope_used_) {
|
||||
*status_ =
|
||||
errors::InvalidArgument("Cannot get a unique name in this scope");
|
||||
return "";
|
||||
}
|
||||
*scope_used_ = true;
|
||||
return op_name_;
|
||||
}
|
||||
return op_name_.empty() ? GetNameForOp(default_name) : GetNameForOp(op_name_);
|
||||
}
|
||||
|
||||
Scope Scope::NewSubScope(const string& child_scope_name) const {
|
||||
if (child_scope_name.empty()) {
|
||||
return Scope(*this, Scope::Tags::ScopeName(), name_, true /* copy_names */);
|
||||
}
|
||||
const string unique_name =
|
||||
GetUniqueName(child_scope_name, false /* check_single_use */);
|
||||
const string sep = name_.empty() || unique_name.empty() ? "" : "/";
|
||||
return Scope(*this, Scope::Tags::ScopeName(),
|
||||
strings::StrCat(name_, sep, unique_name),
|
||||
false /* copy_names */);
|
||||
}
|
||||
|
||||
Scope Scope::WithOpName(const string& op_name) const {
|
||||
if (single_use_scope()) {
|
||||
UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
|
||||
" on this scope"));
|
||||
return *this;
|
||||
}
|
||||
return Scope(*this, Scope::Tags::OpName(), name_, op_name);
|
||||
}
|
||||
|
||||
Scope Scope::WithControlDependencies(
|
||||
const gtl::ArraySlice<ops::Operation>& control_deps) const {
|
||||
return Scope(
|
||||
*this, Scope::Tags::ControlDeps(),
|
||||
std::vector<ops::Operation>(control_deps.begin(), control_deps.end()),
|
||||
/* clear_control_deps */ false);
|
||||
}
|
||||
|
||||
Scope Scope::WithControlDependencies(const ops::Output& control_dep) const {
|
||||
return Scope(*this, Scope::Tags::ControlDeps(),
|
||||
std::vector<ops::Operation>(1, control_dep.op()),
|
||||
/* clear_control_deps */ false);
|
||||
}
|
||||
|
||||
Scope Scope::WithNoControlDependencies() const {
|
||||
return Scope(*this, Scope::Tags::ControlDeps(), std::vector<ops::Operation>(),
|
||||
/* clear_control_deps */ true);
|
||||
}
|
||||
|
||||
Scope Scope::WithDevice(const string& device) const {
|
||||
return Scope(*this, Scope::Tags::Device(), device);
|
||||
}
|
||||
|
||||
Scope Scope::ColocateWith(const ops::Operation& op) const {
|
||||
return Scope(*this, Scope::Tags::Colocate(), op,
|
||||
/* clear_colocations */ false);
|
||||
}
|
||||
|
||||
Scope Scope::ClearColocation() const {
|
||||
return Scope(*this, Scope::Tags::Colocate(), ops::Operation(),
|
||||
/* clear_colocations */ true);
|
||||
}
|
||||
|
||||
Scope Scope::ExitOnError() const {
|
||||
return Scope(*this, Scope::Tags::ExitOnError());
|
||||
}
|
||||
|
||||
Scope Scope::WithKernelLabel(const string& kernel_label) const {
|
||||
return Scope(*this, Scope::Tags::KernelLabel(), kernel_label);
|
||||
}
|
||||
|
||||
CompositeOpScopes Scope::GetCompositeOpScopes(
|
||||
const string& composite_op_name) const {
|
||||
if (op_name_.empty() && composite_op_name.empty()) {
|
||||
UpdateStatus(errors::InvalidArgument(
|
||||
"Cannot create composite op scopes with empty name"));
|
||||
return {*this, *this};
|
||||
}
|
||||
if (!single_use_scope()) {
|
||||
Scope child = NewSubScope(op_name_.empty() ? composite_op_name : op_name_);
|
||||
const string child_op_sep = name_.empty() ? "" : "_";
|
||||
return {child, Scope(child, Scope::Tags::SingleUseScope(),
|
||||
strings::StrCat(name_, child_op_sep, child.name_))};
|
||||
} else {
|
||||
return {
|
||||
Scope(*this, Scope::Tags::ScopeName(), op_name_, true /* copy_names */),
|
||||
*this};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
260
tensorflow/cc/framework/scope.h
Normal file
260
tensorflow/cc/framework/scope.h
Normal file
@ -0,0 +1,260 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class GraphDef;
|
||||
class NodeBuilder;
|
||||
struct CompositeOpScopes;
|
||||
|
||||
// A `Scope` object represents a set of related TensorFlow ops that have the
|
||||
// same properties such as a common name prefix.
|
||||
// A Scope object is a container for TensorFlow Op properties. Op constructors
|
||||
// get a Scope object as a mandatory first argument and the constructed op
|
||||
// acquires the properties in the object.
|
||||
//
|
||||
// A simple example:
|
||||
//
|
||||
// using namespace ops;
|
||||
// Scope root = Scope::NewRootScope();
|
||||
// auto c1 = Const(root, {{1, 1}});
|
||||
// auto m = MatMul(root, c1, {{41}, {1}});
|
||||
// GraphDef gdef;
|
||||
// Status s = root.ToGraphDef(&gdef);
|
||||
// if (!s.ok()) { /* Handle error */ }
|
||||
//
|
||||
// Scope hierarchy:
|
||||
// The Scope class provides various With<> functions that create a new scope.
|
||||
// The new scope typically has one property changed while other properties are
|
||||
// inherited from the parent scope.
|
||||
// NewSubScope(name) method appends `name` to the prefix of names for ops
|
||||
// created within the scope, and WithOpName() changes the suffix which
|
||||
// otherwise defaults to the type of the op.
|
||||
//
|
||||
// Name examples:
|
||||
// Scope root = Scope::NewRootScope();
|
||||
// Scope linear = root.NewSubScope("linear");
|
||||
// /* W will be named "linear/W" */
|
||||
// auto W = Variable(linear.WithOpName("W"),
|
||||
// {2, 2}, DT_FLOAT);
|
||||
// /* b will be named "linear/b" */
|
||||
// auto b = Variable(linear.WithOpName("b"),
|
||||
// {2}, DT_FLOAT);
|
||||
// auto x = Const(linear, {...}); // name: "linear/Const"
|
||||
// auto m = MatMul(linear, x, W); // name: "linear/MatMul"
|
||||
// auto r = BiasAdd(linear, m, b); // name: "linear/BiasAdd"
|
||||
//
|
||||
// Scope lifetime:
|
||||
// A new scope is created by calling Scope::NewRootScope. This creates some
|
||||
// resources that are shared by all the child scopes that inherit from this
|
||||
// scope, directly or transitively. For instance, a new scope creates a new
|
||||
// Graph object to which operations are added when the new scope or its children
|
||||
// are used by an Op constructor. The new scope also has a Status object which
|
||||
// will be used to indicate errors by Op-constructor functions called on any
|
||||
// child scope. The Op-constructor functions have to check the scope's status by
|
||||
// calling the ok() method before proceeding to construct the op.
|
||||
class Scope {
|
||||
public:
|
||||
// The following functions are for users making graphs. They return brand new
|
||||
// scopes, or scopes derived from an existing scope object.
|
||||
|
||||
// Return a new scope.
|
||||
// This creates a new graph and all operations constructed in this graph
|
||||
// should use the returned object as the "root" scope.
|
||||
static Scope NewRootScope();
|
||||
|
||||
// Return a new scope. Ops created with this scope will have
|
||||
// <name>/<child_scope_name> as the prefix. The actual name will be unique
|
||||
// in the current scope. All other properties are inherited from the current
|
||||
// scope. If child_scope_name is empty, the '/' is elided.
|
||||
Scope NewSubScope(const string& child_scope_name) const;
|
||||
|
||||
// Return a new scope. All ops created within the returned scope will have
|
||||
// names of the form <name>/<op_name>[_<suffix].
|
||||
Scope WithOpName(const string& op_name) const;
|
||||
|
||||
// Return a new scope. All ops created within the returned scope will have as
|
||||
// control dependencies the union of operations in the control_deps vector and
|
||||
// the control dependencies of the current scope.
|
||||
Scope WithControlDependencies(
|
||||
const gtl::ArraySlice<ops::Operation>& control_deps) const;
|
||||
// Same as above, but convenient to add control dependency on the operation
|
||||
// producing the control_dep output.
|
||||
Scope WithControlDependencies(const ops::Output& control_dep) const;
|
||||
|
||||
// Return a new scope. All ops created within the returned scope will have no
|
||||
// control dependencies on other operations.
|
||||
Scope WithNoControlDependencies() const;
|
||||
|
||||
// Return a new scope. All ops created within the returned scope will have the
|
||||
// device field set to 'device'.
|
||||
Scope WithDevice(const string& device) const;
|
||||
|
||||
// Return a new scope. All ops created within the returned scope will be
|
||||
// co-located on the device where op is placed.
|
||||
// NOTE: This function is intended to be use internal libraries only for
|
||||
// controlling placement of ops on to devices. Public use is not encouraged
|
||||
// because the implementation of device placement is subject to change.
|
||||
Scope ColocateWith(const ops::Operation& op) const;
|
||||
// Convenience function for above.
|
||||
Scope ColocateWith(const ops::Output& out) const {
|
||||
return ColocateWith(out.op());
|
||||
}
|
||||
// Clear all colocation constraints.
|
||||
Scope ClearColocation() const;
|
||||
|
||||
// Return a new scope. The op-constructor functions taking the returned scope
|
||||
// as the scope argument will exit as soon as an error is detected, instead of
|
||||
// setting the status on the scope.
|
||||
Scope ExitOnError() const;
|
||||
|
||||
// Return a new scope. All ops created with the new scope will have
|
||||
// kernel_label as the value for their '_kernel' attribute;
|
||||
Scope WithKernelLabel(const string& kernel_label) const;
|
||||
|
||||
// The following functions are for scope object consumers.
|
||||
|
||||
// Return a unique name, using default_name if an op name has not been
|
||||
// specified.
|
||||
string GetUniqueNameForOp(const string& default_name) const;
|
||||
|
||||
// Update the status on this scope.
|
||||
// Note: The status object is shared between all children of this scope.
|
||||
// If the resulting status is not Status::OK() and exit_on_error_ is set on
|
||||
// this scope, this function exits by calling LOG(FATAL).
|
||||
void UpdateStatus(const Status s) const;
|
||||
|
||||
// Update the builder with properties accumulated in this scope.
|
||||
void UpdateBuilder(NodeBuilder* builder) const;
|
||||
|
||||
CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const;
|
||||
|
||||
bool ok() const { return status_->ok(); }
|
||||
|
||||
Graph* graph() const { return graph_.get(); }
|
||||
|
||||
Status status() const { return *status_; }
|
||||
|
||||
// If status() is Status::OK(), convert the Graph object stored in this scope
|
||||
// to a GraphDef proto and return Status::OK(). Otherwise, return the error
|
||||
// status as is without performing GraphDef conversion.
|
||||
Status ToGraphDef(GraphDef* gdef) const;
|
||||
|
||||
// If status() is Status::OK(), construct a Graph object using the default
|
||||
// GraphConstructorOptions, and return Status::OK if graph construction was
|
||||
// successful. Otherwise, return the error status.
|
||||
// TODO(josh11b, keveman): Make this faster; right now it converts
|
||||
// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
|
||||
// edges from the source and to the sink node, resolves back edges
|
||||
// by name), and makes sure the resulting graph is valid.
|
||||
Status ToGraph(Graph* g) const;
|
||||
|
||||
const std::vector<ops::Operation>& control_deps() const {
|
||||
return control_deps_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Tag types to choose the constructor to dispatch.
|
||||
struct Tags {
|
||||
enum class ScopeName;
|
||||
enum class OpName;
|
||||
enum class ControlDeps;
|
||||
enum class Device;
|
||||
enum class SingleUseScope;
|
||||
enum class ExitOnError;
|
||||
enum class KernelLabel;
|
||||
enum class Colocate;
|
||||
};
|
||||
|
||||
// A NameMap is used to keep track of suffixes for names used in a scope. A
|
||||
// name that has not been used so far in a scope will get no suffix. Later
|
||||
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
|
||||
// can be sharing the same NameMap. For instance, a new scope created using
|
||||
// WithControlDependencies() should would share the same NameMap with the
|
||||
// parent.
|
||||
typedef std::unordered_map<string, int> NameMap;
|
||||
|
||||
Scope(Graph* graph, Status* status, NameMap* name_map);
|
||||
Scope(const Scope& other, Tags::ScopeName, const string& name,
|
||||
bool copy_names);
|
||||
Scope(const Scope& other, Tags::OpName, const string& name,
|
||||
const string& op_name);
|
||||
Scope(const Scope& other, Tags::ControlDeps,
|
||||
std::vector<ops::Operation> control_deps, bool clear_control_deps);
|
||||
Scope(const Scope& other, Tags::Device, const string& device);
|
||||
Scope(const Scope& other, Tags::SingleUseScope, const string& op_name);
|
||||
Scope(const Scope& other, Tags::ExitOnError);
|
||||
Scope(const Scope& other, Tags::KernelLabel, const string& kernel_label);
|
||||
Scope(const Scope& other, Tags::Colocate,
|
||||
const ops::Operation& colocate_with_op, bool clear_colocations);
|
||||
|
||||
std::unordered_set<string> GetColocationConstraints(
|
||||
const ops::Operation& colocate_with_op) const;
|
||||
|
||||
// Helper functions to get a unique names.
|
||||
string GetUniqueName(const string& prefix, bool check_single_use) const;
|
||||
string GetNameForOp(const string& default_name) const;
|
||||
|
||||
bool single_use_scope() const { return scope_used_ != nullptr; }
|
||||
|
||||
// The graph, status, and name maps are shared by all child scopes
|
||||
// created from a single 'root' scope. A root scope is created by calling the
|
||||
// Scope::NewRootScope function, which creates a new graph, a new status and
|
||||
// the name maps.
|
||||
std::shared_ptr<Graph> graph_ = nullptr;
|
||||
std::shared_ptr<Status> status_ = nullptr;
|
||||
std::shared_ptr<NameMap> name_map_ = nullptr;
|
||||
|
||||
// If scope_used_ is not nullptr, op_name_ should be empty and
|
||||
// GetUniqueNameForOp can only be called once on this scope. More calls to
|
||||
// GetUniqueNameForOp will cause an error status to be set on this scope.
|
||||
std::shared_ptr<bool> scope_used_ = nullptr;
|
||||
|
||||
const std::vector<ops::Operation> control_deps_;
|
||||
|
||||
const string name_ = "";
|
||||
const string op_name_ = "";
|
||||
const bool exit_on_error_ = false;
|
||||
const string kernel_label_ = "";
|
||||
const string device_ = "";
|
||||
const std::unordered_set<string> colocation_constraints_;
|
||||
};
|
||||
|
||||
// A helper struct to hold the scopes that would be used by a function
|
||||
// constructing a composite op.
|
||||
struct CompositeOpScopes {
|
||||
// Scope to be used for creating the local ops (primitive or other composite
|
||||
// ops).
|
||||
Scope child;
|
||||
// Scope to be used for creating the last op.
|
||||
Scope last;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
|
138
tensorflow/cc/framework/scope_test.cc
Normal file
138
tensorflow/cc/framework/scope_test.cc
Normal file
@ -0,0 +1,138 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(ScopeTest, BasicNames) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add");
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add_1");
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("add"), "add_2");
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul");
|
||||
}
|
||||
|
||||
TEST(ScopeTest, HierarchicalNames) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope child = root.NewSubScope("child");
|
||||
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add");
|
||||
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add_1");
|
||||
EXPECT_EQ(child.GetUniqueNameForOp("mul"), "child/mul");
|
||||
|
||||
Scope child_1 = root.NewSubScope("child");
|
||||
EXPECT_EQ(child_1.GetUniqueNameForOp("add"), "child_1/add");
|
||||
EXPECT_EQ(child_1.GetUniqueNameForOp("add"), "child_1/add_1");
|
||||
EXPECT_EQ(child_1.GetUniqueNameForOp("mul"), "child_1/mul");
|
||||
|
||||
Scope c_c = root.NewSubScope("c").NewSubScope("c");
|
||||
EXPECT_EQ(c_c.GetUniqueNameForOp("add"), "c/c/add");
|
||||
|
||||
Scope c_1 = root.NewSubScope("c");
|
||||
Scope c_1_c = c_1.NewSubScope("c");
|
||||
EXPECT_EQ(c_1_c.GetUniqueNameForOp("add"), "c_1/c/add");
|
||||
|
||||
Scope c_1_c_1 = c_1.NewSubScope("c");
|
||||
EXPECT_EQ(c_1_c_1.GetUniqueNameForOp("add"), "c_1/c_1/add");
|
||||
|
||||
EXPECT_EQ(root.NewSubScope("").NewSubScope("").GetUniqueNameForOp("d"), "d");
|
||||
EXPECT_EQ(root.NewSubScope("").GetUniqueNameForOp("d"), "d_1");
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("d"), "d_2");
|
||||
}
|
||||
|
||||
TEST(ScopeTest, ScopeAndOpNames) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope child = root.NewSubScope("child");
|
||||
|
||||
EXPECT_EQ(child.GetUniqueNameForOp("add"), "child/add");
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("child"), "child_1");
|
||||
|
||||
EXPECT_EQ(root.NewSubScope("child").GetUniqueNameForOp("p"), "child_2/p");
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
string LastOp(const Scope& scope) { return scope.GetUniqueNameForOp("Last"); }
|
||||
|
||||
std::vector<string> AnotherCompositeOp(const Scope& scope) {
|
||||
auto cop_scopes = scope.GetCompositeOpScopes("another_cop");
|
||||
const string c1 = cop_scopes.child.GetUniqueNameForOp("c1");
|
||||
const string c2 = cop_scopes.child.GetUniqueNameForOp("mul");
|
||||
return {c1, c2, LastOp(cop_scopes.last)};
|
||||
}
|
||||
|
||||
std::vector<string> LinearOp(const Scope& scope) {
|
||||
auto cop_scopes = scope.GetCompositeOpScopes("linear");
|
||||
Scope linear = cop_scopes.child;
|
||||
const string mul_op_name = linear.GetUniqueNameForOp("mul");
|
||||
const string bias_add_op_name = linear.GetUniqueNameForOp("bias_add");
|
||||
auto cop_names = AnotherCompositeOp(cop_scopes.last);
|
||||
return {mul_op_name, bias_add_op_name, cop_names[0], cop_names[1],
|
||||
cop_names[2]};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(ScopeTest, CompositeOp) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
const auto names1 = LinearOp(root);
|
||||
|
||||
EXPECT_EQ(names1[0], "linear/mul");
|
||||
EXPECT_EQ(names1[1], "linear/bias_add");
|
||||
EXPECT_EQ(names1[2], "linear/c1");
|
||||
EXPECT_EQ(names1[3], "linear/mul_1");
|
||||
EXPECT_EQ(names1[4], "linear");
|
||||
|
||||
EXPECT_EQ(root.GetUniqueNameForOp("linear"), "linear_1");
|
||||
|
||||
const auto names2 = LinearOp(root);
|
||||
|
||||
EXPECT_EQ(names2[0], "linear_2/mul");
|
||||
EXPECT_EQ(names2[1], "linear_2/bias_add");
|
||||
EXPECT_EQ(names2[2], "linear_2/c1");
|
||||
EXPECT_EQ(names2[3], "linear_2/mul_1");
|
||||
EXPECT_EQ(names2[4], "linear_2");
|
||||
|
||||
const auto names3 = LinearOp(root.WithOpName("c"));
|
||||
|
||||
EXPECT_EQ(names3[0], "c/mul");
|
||||
EXPECT_EQ(names3[1], "c/bias_add");
|
||||
EXPECT_EQ(names3[2], "c/c1");
|
||||
EXPECT_EQ(names3[3], "c/mul_1");
|
||||
EXPECT_EQ(names3[4], "c");
|
||||
}
|
||||
|
||||
TEST(ScopeTest, SingleUseScope) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto cop_scopes = root.GetCompositeOpScopes("cop");
|
||||
// cop_scopes.last is a single use scope
|
||||
EXPECT_EQ(cop_scopes.last.GetUniqueNameForOp("foo"), "cop");
|
||||
cop_scopes.last.GetUniqueNameForOp("foo");
|
||||
// Error status should be set on cop_scopes.last
|
||||
EXPECT_FALSE(cop_scopes.last.ok());
|
||||
}
|
||||
|
||||
TEST(ScopeTest, ControlDeps) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c1 = ops::Operation();
|
||||
auto c2 = ops::Operation();
|
||||
Scope c = root.WithControlDependencies({c1, c2});
|
||||
EXPECT_EQ(c.control_deps().size(), 2);
|
||||
Scope c_c = c.WithControlDependencies({ops::Operation()});
|
||||
EXPECT_EQ(c_c.control_deps().size(), 3);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
47
tensorflow/cc/framework/test_op.cc
Normal file
47
tensorflow/cc/framework/test_op.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("ThrowAway1")
|
||||
.Input("ret: int32")
|
||||
.Input("unique_name: float")
|
||||
.Input("for: int32")
|
||||
.Attr("scope: int")
|
||||
.Attr("builder: int = 1")
|
||||
.Attr("while: int")
|
||||
.Doc(R"doc(
|
||||
Op to test keywords and reserved words in input and attr names.
|
||||
|
||||
ret: Return value.
|
||||
for: Keyword as name for input.
|
||||
while: Keyword as name for attr.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ThrowAway2")
|
||||
.Attr("scope: int = 2")
|
||||
.Attr("throw_away2: int = 2")
|
||||
.Attr("attrs: int = 4")
|
||||
.Attr("node: int = 4");
|
||||
|
||||
REGISTER_OP("ThrowAway3").Output("node: int32");
|
||||
|
||||
REGISTER_OP("ThrowAway4").Input("node: int32");
|
||||
|
||||
REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,382 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords
|
||||
// or "opts".
|
||||
|
||||
#include "tensorflow/cc/ops/cc_op_gen.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
const int kRightMargin = 79;
|
||||
|
||||
const char* AttrTypeName(StringPiece attr_type) {
|
||||
static const char* kAttrTypeName[][2] = {
|
||||
{"string", "StringPiece"},
|
||||
{"list(string)", "gtl::ArraySlice<string>"},
|
||||
{"int", "int64"},
|
||||
{"list(int)", "gtl::ArraySlice<int>"},
|
||||
{"float", "float"},
|
||||
{"list(float)", "gtl::ArraySlice<float>"},
|
||||
{"bool", "bool"},
|
||||
{"list(bool)", "gtl::ArraySlice<bool>"},
|
||||
{"type", "DataType"},
|
||||
{"list(type)", "DataTypeSlice"},
|
||||
{"shape", "TensorShape"},
|
||||
{"list(shape)", "gtl::ArraySlice<TensorShape>"},
|
||||
{"tensor", "const Tensor&"},
|
||||
{"list(tensor)", "gtl::ArraySlice<Tensor>"},
|
||||
{"func", "const NameAttrList&"},
|
||||
};
|
||||
for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
|
||||
if (attr_type == kAttrTypeName[i][0]) {
|
||||
return kAttrTypeName[i][1];
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
|
||||
return "";
|
||||
}
|
||||
|
||||
// Change: Into:
|
||||
// ABC // ABC
|
||||
// //
|
||||
// DEF // DEF
|
||||
string MakeComment(StringPiece text) {
|
||||
string ret;
|
||||
while (!text.empty()) {
|
||||
int last_non_space = -1;
|
||||
int newline;
|
||||
for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
|
||||
if (text[newline] == '\n') break;
|
||||
if (text[newline] != ' ') last_non_space = newline;
|
||||
}
|
||||
if (last_non_space == -1) {
|
||||
strings::StrAppend(&ret, "//\n");
|
||||
} else {
|
||||
strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n");
|
||||
}
|
||||
text.remove_prefix(newline + 1);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
|
||||
// TODO(josh11b): Better wrapping of comments.
|
||||
string comment;
|
||||
if (op_def.summary().empty()) {
|
||||
comment = "TODO: add doc.\n";
|
||||
} else {
|
||||
comment = strings::StrCat(op_def.summary(), "\n");
|
||||
if (op_def.has_deprecation()) {
|
||||
strings::StrAppend(&comment, "\nDEPRECATED at GraphDef version ",
|
||||
op_def.deprecation().version(), ":\n",
|
||||
op_def.deprecation().explanation(), ".\n");
|
||||
}
|
||||
if (!op_def.description().empty()) {
|
||||
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
|
||||
}
|
||||
}
|
||||
|
||||
static const string kSingleInputType = "NodeOut";
|
||||
static const string kListInputType = "gtl::ArraySlice<NodeOut>";
|
||||
|
||||
std::vector<string> arg_types;
|
||||
std::vector<string> arg_names;
|
||||
|
||||
strings::StrAppend(&comment, "\nArguments:\n");
|
||||
|
||||
// Map from attr name to the first input arg it is inferred from.
|
||||
std::unordered_map<string, string> inferred_attrs;
|
||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||
const auto& arg(op_def.input_arg(i));
|
||||
arg_names.emplace_back(arg.name());
|
||||
bool is_list = false;
|
||||
|
||||
if (!arg.type_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
|
||||
} else if (!arg.type_list_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
|
||||
arg.name());
|
||||
is_list = true;
|
||||
}
|
||||
if (!arg.number_attr().empty()) {
|
||||
gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
|
||||
is_list = true;
|
||||
}
|
||||
if (is_list) {
|
||||
arg_types.emplace_back(kListInputType);
|
||||
} else {
|
||||
arg_types.emplace_back(kSingleInputType);
|
||||
}
|
||||
|
||||
// TODO(josh11b): Include input type information.
|
||||
StringPiece description = arg.description();
|
||||
if (!description.empty()) {
|
||||
ConsumeEquals(&description);
|
||||
strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
|
||||
arg.description(), "\n");
|
||||
}
|
||||
}
|
||||
|
||||
string options_comment;
|
||||
for (int i = 0; i < op_def.attr_size(); ++i) {
|
||||
const auto& attr(op_def.attr(i));
|
||||
// Do not add inferred attrs or attrs with defaults to the C++
|
||||
// function signature.
|
||||
if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
|
||||
if (!attr.has_default_value()) {
|
||||
arg_names.emplace_back(attr.name());
|
||||
arg_types.emplace_back(AttrTypeName(attr.type()));
|
||||
if (!attr.description().empty()) {
|
||||
strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
|
||||
attr.description(), "\n");
|
||||
}
|
||||
} else {
|
||||
strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(),
|
||||
"\", ", AttrTypeName(attr.type()), "): Defaults to ",
|
||||
SummarizeAttrValue(attr.default_value()), ".\n");
|
||||
if (!attr.description().empty()) {
|
||||
strings::StrAppend(&options_comment, " ", attr.description(),
|
||||
"\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_EQ(arg_names.size(), arg_types.size());
|
||||
strings::StrAppend(&comment, "* opts:\n", options_comment,
|
||||
R"comment( .WithName(StringPiece): Set the Node's name
|
||||
.WithDevice(StringPiece): Set the Node's requested device
|
||||
.WithControlInput(Node*) / .WithControlInputs({Node*, ...}):
|
||||
Add control dependencies on the specified Node(s).
|
||||
|
||||
Returns a pointer to the created Node)comment");
|
||||
|
||||
// TODO(josh11b): Include output type information.
|
||||
if (op_def.output_arg_size() == 0) {
|
||||
strings::StrAppend(&comment, ".\n");
|
||||
} else if (op_def.output_arg_size() == 1) {
|
||||
StringPiece description = op_def.output_arg(0).description();
|
||||
ConsumeEquals(&description);
|
||||
if (description.empty()) {
|
||||
strings::StrAppend(&comment, ".\n");
|
||||
} else {
|
||||
strings::StrAppend(&comment, ", with output:\n", description, "\n");
|
||||
}
|
||||
} else {
|
||||
strings::StrAppend(&comment, ", with outputs:\n");
|
||||
for (int o = 0; o < op_def.output_arg_size(); ++o) {
|
||||
StringPiece description = op_def.output_arg(o).description();
|
||||
ConsumeEquals(&description);
|
||||
if (description.empty()) {
|
||||
strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n");
|
||||
} else {
|
||||
strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ",
|
||||
description, "\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the header comment.
|
||||
TF_CHECK_OK(h->Append(MakeComment(comment)));
|
||||
|
||||
// Declare the function wrapper.
|
||||
const string prefix = strings::StrCat("Node* ", op_def.name(), "(");
|
||||
string h_rest;
|
||||
for (size_t i = 0; i < arg_names.size(); ++i) {
|
||||
strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", ");
|
||||
}
|
||||
strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts");
|
||||
string cc_decl = h_rest;
|
||||
strings::StrAppend(&h_rest, ");");
|
||||
TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n"));
|
||||
|
||||
// Define the function wrapper.
|
||||
strings::StrAppend(&cc_decl, ") {");
|
||||
TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n"));
|
||||
const string op_name = strings::StrCat(" static const string kOpName = \"",
|
||||
op_def.name(), "\";\n");
|
||||
|
||||
if (arg_types.empty()) {
|
||||
TF_CHECK_OK(cc->Append(op_name));
|
||||
TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n"));
|
||||
} else if (arg_types == std::vector<string>({kSingleInputType})) {
|
||||
TF_CHECK_OK(cc->Append(op_name));
|
||||
TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ",
|
||||
arg_names[0], ", opts);\n}\n\n")));
|
||||
} else if (arg_types ==
|
||||
std::vector<string>({kSingleInputType, kSingleInputType})) {
|
||||
TF_CHECK_OK(cc->Append(op_name));
|
||||
// TODO(josh11b): Word wrap this if it ever becomes necessary.
|
||||
TF_CHECK_OK(
|
||||
cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0],
|
||||
", ", arg_names[1], ", opts);\n}\n\n")));
|
||||
} else {
|
||||
TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n"));
|
||||
TF_CHECK_OK(cc->Append(op_name));
|
||||
TF_CHECK_OK(cc->Append(
|
||||
" NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n"
|
||||
" opts.op_registry());\n"));
|
||||
for (size_t i = 0; i < arg_names.size(); ++i) {
|
||||
if (i < static_cast<size_t>(op_def.input_arg_size())) {
|
||||
TF_CHECK_OK(cc->Append(
|
||||
strings::StrCat(" node_builder.Input(", arg_names[i], ");\n")));
|
||||
} else {
|
||||
TF_CHECK_OK(
|
||||
cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i],
|
||||
"\", ", arg_names[i], ");\n")));
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(
|
||||
cc->Append(" return opts.FinalizeBuilder(&node_builder);\n"
|
||||
"}\n\n"));
|
||||
}
|
||||
}
|
||||
|
||||
// Converts:
|
||||
// bazel-out/.../genfiles/(external/YYY/)?XX
|
||||
// to: XX.
|
||||
string GetPath(const std::string& dot_h_fname) {
|
||||
auto pos = dot_h_fname.find("/genfiles/");
|
||||
string result = dot_h_fname;
|
||||
if (pos != string::npos) {
|
||||
// - 1 account for the terminating null character (\0) in "/genfiles/".
|
||||
result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
|
||||
}
|
||||
if (result.size() > sizeof("external/") &&
|
||||
result.compare(0, sizeof("external/") - 1, "external/") == 0) {
|
||||
result = result.substr(sizeof("external/") - 1);
|
||||
pos = result.find("/");
|
||||
if (pos != string::npos) {
|
||||
result = result.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts:
|
||||
// cc/ops/gen_foo_ops.h
|
||||
// to:
|
||||
// CC_OPS_GEN_FOO_OPS_H_
|
||||
string ToGuard(const std::string& path) {
|
||||
string guard;
|
||||
guard.reserve(path.size() + 1); // + 1 -> trailing _
|
||||
for (const char c : path) {
|
||||
if (c >= 'A' && c <= 'Z') {
|
||||
guard += c;
|
||||
} else if (c >= 'a' && c <= 'z') {
|
||||
guard += c + 'A' - 'a';
|
||||
} else {
|
||||
guard += '_';
|
||||
}
|
||||
}
|
||||
guard += '_';
|
||||
return guard;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
|
||||
const std::string& dot_cc_fname) {
|
||||
Env* env = Env::Default();
|
||||
std::unique_ptr<WritableFile> h;
|
||||
std::unique_ptr<WritableFile> cc;
|
||||
TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
|
||||
TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
|
||||
|
||||
// .h Header
|
||||
const string include = GetPath(dot_h_fname);
|
||||
const string guard = ToGuard(include);
|
||||
// TODO(josh11b): Mention the library for which wrappers are being generated.
|
||||
Status s;
|
||||
s = h->Append(
|
||||
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
|
||||
"#ifndef ",
|
||||
guard,
|
||||
"\n"
|
||||
"#define ",
|
||||
guard, R"header(
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
// These add a node to the graph from opts.
|
||||
//
|
||||
// Note for "NodeOut" inputs, you will typically either pass
|
||||
// * a {Node*, int index} (to pass the index-th output of that node), or
|
||||
// * a Node* (to pass the first output of that node).
|
||||
|
||||
|
||||
)header"));
|
||||
TF_CHECK_OK(s);
|
||||
// .cc Header
|
||||
s = cc->Append(
|
||||
strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
|
||||
"#include \"",
|
||||
include, R"header("
|
||||
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
)header"));
|
||||
TF_CHECK_OK(s);
|
||||
|
||||
for (const auto& op_def : ops.op()) {
|
||||
WriteCCOp(op_def, h.get(), cc.get());
|
||||
}
|
||||
|
||||
// .h Footer
|
||||
|
||||
s = h->Append(strings::StrCat(R"footer(} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // )footer",
|
||||
guard, "\n"));
|
||||
TF_CHECK_OK(s);
|
||||
|
||||
// .cc Footer
|
||||
|
||||
s = cc->Append(R"footer(} // namespace ops
|
||||
} // namespace tensorflow
|
||||
)footer");
|
||||
TF_CHECK_OK(s);
|
||||
|
||||
TF_CHECK_OK(cc->Close());
|
||||
TF_CHECK_OK(h->Close());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,119 +14,59 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
namespace {
|
||||
const string& OpName() {
|
||||
static const string kOpName = "Const";
|
||||
return kOpName;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define DEFINE_CONST_SCALAR(TYPE) \
|
||||
Node* Const(TYPE s, const GraphDefBuilder::Options& options) { \
|
||||
return Const(gtl::ArraySlice<TYPE>(&s, 1), TensorShape({}), options); \
|
||||
Output Const(const Scope& scope, const Input::Initializer& val) {
|
||||
if (!scope.ok()) return Output();
|
||||
if (!val.status.ok()) {
|
||||
scope.UpdateStatus(val.status);
|
||||
return Output();
|
||||
}
|
||||
|
||||
#define DEFINE_CONST_VECTOR(TYPE) \
|
||||
Node* Const(gtl::ArraySlice<TYPE> v, \
|
||||
const GraphDefBuilder::Options& options) { \
|
||||
return Const(v, TensorShape({static_cast<int64>(v.size())}), options); \
|
||||
}
|
||||
Node* ret;
|
||||
Graph* graph = scope.graph();
|
||||
const string unique_name = scope.GetUniqueNameForOp("Const");
|
||||
auto builder = NodeBuilder(unique_name, "Const")
|
||||
.Attr("value", val.tensor)
|
||||
.Attr("dtype", val.tensor.dtype());
|
||||
scope.UpdateBuilder(&builder);
|
||||
scope.UpdateStatus(builder.Finalize(graph, &ret));
|
||||
|
||||
#define DEFINE_CONST_TENSOR(TYPE, ...) \
|
||||
Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
|
||||
const GraphDefBuilder::Options& options) { \
|
||||
if (options.HaveError()) return nullptr; \
|
||||
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(), \
|
||||
options.op_registry()); \
|
||||
const DataType dt = DataTypeToEnum<TYPE>::v(); \
|
||||
if (t.size() == 1) { \
|
||||
TensorProto proto; \
|
||||
proto.set_dtype(dt); \
|
||||
shape.AsProto(proto.mutable_tensor_shape()); \
|
||||
__VA_ARGS__; \
|
||||
node_builder.Attr("dtype", dt).Attr("value", proto); \
|
||||
} else { \
|
||||
Tensor tensor(dt, shape); \
|
||||
if (tensor.NumElements() != static_cast<int64>(t.size())) { \
|
||||
options.UpdateStatus(errors::InvalidArgument( \
|
||||
t.size(), " values provided to Const() != ", tensor.NumElements(), \
|
||||
" elements for shape ", shape.DebugString())); \
|
||||
} else { \
|
||||
std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data()); \
|
||||
node_builder.Attr("dtype", dt).Attr("value", tensor); \
|
||||
} \
|
||||
} \
|
||||
return options.FinalizeBuilder(&node_builder); \
|
||||
}
|
||||
if (!scope.ok()) return Output();
|
||||
|
||||
#define DEFINE_CONST_IMPL(TYPE, ...) \
|
||||
DEFINE_CONST_SCALAR(TYPE) \
|
||||
DEFINE_CONST_VECTOR(TYPE) \
|
||||
DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__)
|
||||
|
||||
#define DEFINE_CONST(TYPE, FIELD) \
|
||||
DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());)
|
||||
|
||||
DEFINE_CONST(float, float_val);
|
||||
DEFINE_CONST(double, double_val);
|
||||
DEFINE_CONST(int32, int_val);
|
||||
DEFINE_CONST(uint8, int_val);
|
||||
DEFINE_CONST(int16, int_val);
|
||||
DEFINE_CONST(int8, int_val);
|
||||
DEFINE_CONST(int64, int64_val);
|
||||
DEFINE_CONST(bool, bool_val);
|
||||
|
||||
DEFINE_CONST_IMPL(Eigen::half, proto.add_half_val(t.begin()->x));
|
||||
|
||||
DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
|
||||
proto.add_scomplex_val(t.begin()->imag()););
|
||||
|
||||
DEFINE_CONST_IMPL(complex128, proto.add_dcomplex_val(t.begin()->real());
|
||||
proto.add_dcomplex_val(t.begin()->imag()););
|
||||
|
||||
Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
|
||||
if (options.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
|
||||
options.op_registry());
|
||||
TensorProto proto;
|
||||
proto.set_dtype(DT_STRING);
|
||||
TensorShape({}).AsProto(proto.mutable_tensor_shape());
|
||||
proto.add_string_val(s.data(), s.size());
|
||||
node_builder.Attr("dtype", DT_STRING).Attr("value", proto);
|
||||
return options.FinalizeBuilder(&node_builder);
|
||||
return Output(ret);
|
||||
}
|
||||
|
||||
DEFINE_CONST_VECTOR(string)
|
||||
DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());)
|
||||
|
||||
#undef DEFINE_CONST
|
||||
#undef DEFINE_CONST_IMPL
|
||||
#undef DEFINE_CONST_TENSOR
|
||||
#undef DEFINE_CONST_VECTOR
|
||||
#undef DEFINE_CONST_SCALAR
|
||||
|
||||
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
|
||||
if (options.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
|
||||
options.op_registry());
|
||||
node_builder.Attr("dtype", t.dtype()).Attr("value", t);
|
||||
return options.FinalizeBuilder(&node_builder);
|
||||
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) {
|
||||
if (!inp.status().ok()) {
|
||||
scope.UpdateStatus(inp.status());
|
||||
return NodeBuilder::NodeOut(inp.node(), inp.index());
|
||||
}
|
||||
if (inp.node()) {
|
||||
return NodeBuilder::NodeOut(inp.node(), inp.index());
|
||||
}
|
||||
if (!inp.node_name().empty()) {
|
||||
return NodeBuilder::NodeOut(inp.node_name(), inp.index(), inp.data_type());
|
||||
}
|
||||
auto transformed = Input{
|
||||
Const(scope.NewSubScope("Const"), Input::Initializer(inp.tensor()))};
|
||||
return NodeBuilder::NodeOut{transformed.node(), transformed.index()};
|
||||
}
|
||||
|
||||
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) {
|
||||
if (options.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),
|
||||
options.op_registry());
|
||||
node_builder.Attr("dtype", proto.dtype()).Attr("value", proto);
|
||||
return options.FinalizeBuilder(&node_builder);
|
||||
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
|
||||
const InputList& inp) {
|
||||
std::vector<NodeBuilder::NodeOut> out;
|
||||
for (const auto& i : inp) {
|
||||
const auto node_out = AsNodeOut(scope, i);
|
||||
if (!scope.ok()) {
|
||||
return {};
|
||||
}
|
||||
out.push_back(node_out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,75 +13,53 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
#define TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
// If a shape is specified, you may either provide the same number of values,
|
||||
// or a single value and that value will be duplicated to fill out the Tensor.
|
||||
#define DECLARE_CONST(TYPE) \
|
||||
Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \
|
||||
Node* Const(gtl::ArraySlice<TYPE> v, \
|
||||
const GraphDefBuilder::Options& options); /* Vector */ \
|
||||
Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape, \
|
||||
const GraphDefBuilder::Options& options); /* Tensor */ \
|
||||
inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \
|
||||
const GraphDefBuilder::Options& options) { \
|
||||
return Const(gtl::ArraySlice<TYPE>(v), options); \
|
||||
} \
|
||||
inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \
|
||||
const TensorShape& shape, \
|
||||
const GraphDefBuilder::Options& options) { \
|
||||
return Const(gtl::ArraySlice<TYPE>(t), shape, options); \
|
||||
Output Const(const Scope& scope, const Input::Initializer& val);
|
||||
|
||||
template <typename T>
|
||||
Output Const(const Scope& scope, const Input::Initializer& val) {
|
||||
if (!scope.ok()) return Output();
|
||||
if (!val.status.ok()) {
|
||||
scope.UpdateStatus(val.status);
|
||||
return Output();
|
||||
}
|
||||
|
||||
DECLARE_CONST(Eigen::half);
|
||||
DECLARE_CONST(float);
|
||||
DECLARE_CONST(double);
|
||||
DECLARE_CONST(int32);
|
||||
DECLARE_CONST(uint8);
|
||||
DECLARE_CONST(int16);
|
||||
DECLARE_CONST(int8);
|
||||
DECLARE_CONST(complex64);
|
||||
DECLARE_CONST(complex128);
|
||||
DECLARE_CONST(int64);
|
||||
DECLARE_CONST(bool);
|
||||
|
||||
#undef DECLARE_CONST
|
||||
|
||||
// String
|
||||
Node* Const(StringPiece s, const GraphDefBuilder::Options& options);
|
||||
Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options);
|
||||
Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape,
|
||||
const GraphDefBuilder::Options& options);
|
||||
inline Node* Const(std::initializer_list<string> v,
|
||||
const GraphDefBuilder::Options& options) {
|
||||
return Const(gtl::ArraySlice<string>(v), options);
|
||||
}
|
||||
inline Node* Const(std::initializer_list<string> t, const TensorShape& shape,
|
||||
const GraphDefBuilder::Options& options) {
|
||||
return Const(gtl::ArraySlice<string>(t), shape, options);
|
||||
typedef typename Input::Initializer::RealType<T>::type DstT;
|
||||
if (val.tensor.NumElements() > 0) {
|
||||
// TODO(keveman): Implement the in-situ cast.
|
||||
scope.UpdateStatus(errors::Unimplemented(
|
||||
"Explict cast of a non-empty tensor not implemented yet"));
|
||||
return Output();
|
||||
}
|
||||
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
|
||||
return Const(scope, Input::Initializer(t));
|
||||
}
|
||||
|
||||
// A Tensor of any type.
|
||||
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options);
|
||||
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options);
|
||||
|
||||
template <class T>
|
||||
Node* EmptyConst(const GraphDefBuilder::Options& options) {
|
||||
return Const(gtl::ArraySlice<T>(), options);
|
||||
template <typename T>
|
||||
Output Const(const Scope& scope, const T& v, const TensorShape shape) {
|
||||
return Const(scope, Input::Initializer(v, shape));
|
||||
}
|
||||
|
||||
// TODO(josh11b): Support other types (e.g. quantized ints, float16).
|
||||
template <typename T>
|
||||
Output Const(const Scope& scope, const std::initializer_list<T>& v,
|
||||
const TensorShape shape) {
|
||||
return Const(scope, Input::Initializer(v, shape));
|
||||
}
|
||||
|
||||
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
|
||||
|
||||
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
|
||||
const InputList& inp);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_
|
||||
|
128
tensorflow/cc/ops/const_op_test.cc
Normal file
128
tensorflow/cc/ops/const_op_test.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
|
||||
TensorShape shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
Tensor tensor;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
|
||||
}
|
||||
|
||||
void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
|
||||
TensorShape expected_shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
Tensor tensor;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
EXPECT_EQ(dtype, expected_dtype);
|
||||
EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(ConstOpTest, Basic) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, 42.0f);
|
||||
TF_EXPECT_OK(root.status());
|
||||
EXPECT_EQ(c.op().output_type(0), DT_FLOAT);
|
||||
ExpectNodeEqual<float>(c.node(), {42.0f}, {});
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, MultiDim) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, {{2.0}, {3.0}});
|
||||
TF_CHECK_OK(root.status());
|
||||
EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
|
||||
ExpectNodeEqual<double>(c.node(), {2.0, 3.0}, {2, 1});
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, Empty) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
|
||||
auto c1 = ops::Const(root, {});
|
||||
TF_CHECK_OK(root.status());
|
||||
ExpectTypeAndShape(c1.node(), DT_FLOAT, {0});
|
||||
|
||||
auto c2 = ops::Const(root, {{}});
|
||||
TF_CHECK_OK(root.status());
|
||||
ExpectTypeAndShape(c2.node(), DT_FLOAT, {1, 0});
|
||||
|
||||
auto c3 = ops::Const(root, {{{}, {}}});
|
||||
TF_CHECK_OK(root.status());
|
||||
ExpectTypeAndShape(c3.node(), DT_FLOAT, {1, 2, 0});
|
||||
|
||||
auto c4 = ops::Const<int>(root, {{{}}});
|
||||
TF_CHECK_OK(root.status());
|
||||
ExpectTypeAndShape(c4.node(), DT_INT32, {1, 1, 0});
|
||||
|
||||
ops::Const(root, {{}, {{}}});
|
||||
EXPECT_FALSE(root.status().ok());
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, WithExplicitShape) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, 42.0, {2, 2});
|
||||
TF_CHECK_OK(root.status());
|
||||
EXPECT_EQ(c.op().output_type(0), DT_DOUBLE);
|
||||
ExpectNodeEqual<double>(c.node(), {42.0, 42.0, 42.0, 42.0}, {2, 2});
|
||||
|
||||
auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3});
|
||||
TF_CHECK_OK(root.status());
|
||||
EXPECT_EQ(d.op().output_type(0), DT_STRING);
|
||||
ExpectNodeEqual<string>(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3});
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, InvalidInitializer) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
ops::Const(root, {{2.0}, {"df"}});
|
||||
EXPECT_FALSE(root.status().ok());
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, Names) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, {{2.0}, {3.0}});
|
||||
EXPECT_EQ(c.node()->name(), "Const");
|
||||
auto c_1 = ops::Const(root, {{2.0}, {3.0}});
|
||||
EXPECT_EQ(c_1.node()->name(), "Const_1");
|
||||
|
||||
auto x = ops::Const(root.WithOpName("x"), 1);
|
||||
EXPECT_EQ(x.node()->name(), "x");
|
||||
auto x_1 = ops::Const(root.WithOpName("x"), 1);
|
||||
EXPECT_EQ(x_1.node()->name(), "x_1");
|
||||
|
||||
Scope child = root.NewSubScope("c");
|
||||
auto c_y = ops::Const(child.WithOpName("y"), 1);
|
||||
EXPECT_EQ(c_y.node()->name(), "c/y");
|
||||
auto c_y_1 = ops::Const(child.WithOpName("y"), 1);
|
||||
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// #include this file to get access to the standard set of C++ graph
|
||||
// definition libraries.
|
||||
|
||||
#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/candidate_sampling_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/io_ops.h"
|
||||
@ -28,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/logging_ops.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/no_op.h"
|
||||
#include "tensorflow/cc/ops/parsing_ops.h"
|
||||
#include "tensorflow/cc/ops/random_ops.h"
|
||||
#include "tensorflow/cc/ops/sparse_ops.h"
|
||||
@ -36,4 +36,4 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/training_ops.h"
|
||||
#include "tensorflow/cc/ops/user_ops.h"
|
||||
|
||||
#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_
|
||||
|
@ -49,31 +49,33 @@ struct Options {
|
||||
GraphDef CreateGraphDef() {
|
||||
// TODO(jeff,opensource): This should really be a more interesting
|
||||
// computation. Maybe turn this into an mnist model instead?
|
||||
GraphDefBuilder b;
|
||||
Scope root = Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
// Store rows [3, 2] and [-1, 0] in row major format.
|
||||
Node* a = Const({3.f, 2.f, -1.f, 0.f}, {2, 2}, b.opts());
|
||||
|
||||
// x is from the feed.
|
||||
Node* x = Const({0.f}, {2, 1}, b.opts().WithName("x"));
|
||||
// a = [3 2; -1 0]
|
||||
auto a = Const(root, {{3.f, 2.f}, {-1.f, 0.f}});
|
||||
|
||||
// y = A * x
|
||||
Node* y = MatMul(a, x, b.opts().WithName("y"));
|
||||
// x = [1.0; 1.0]
|
||||
auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});
|
||||
|
||||
// y = a * x
|
||||
auto y = MatMul(root.WithOpName("y"), a, x);
|
||||
|
||||
// y2 = y.^2
|
||||
Node* y2 = Square(y, b.opts());
|
||||
auto y2 = Square(root, y);
|
||||
|
||||
// y2_sum = sum(y2)
|
||||
Node* y2_sum = Sum(y2, Const(0, b.opts()), b.opts());
|
||||
auto y2_sum = Sum(root, y2, 0);
|
||||
|
||||
// y_norm = sqrt(y2_sum)
|
||||
Node* y_norm = Sqrt(y2_sum, b.opts());
|
||||
auto y_norm = Sqrt(root, y2_sum);
|
||||
|
||||
// y_normalized = y ./ y_norm
|
||||
Div(y, y_norm, b.opts().WithName("y_normalized"));
|
||||
Div(root.WithOpName("y_normalized"), y, y_norm);
|
||||
|
||||
GraphDef def;
|
||||
TF_CHECK_OK(b.ToGraphDef(&def));
|
||||
TF_CHECK_OK(root.ToGraphDef(&def));
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
|
@ -42,13 +42,12 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
|
||||
Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT);
|
||||
test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; });
|
||||
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* node1 = ops::Const(test_tensor1, b.opts());
|
||||
Node* node2 = ops::Const(test_tensor2, b.opts());
|
||||
const string result_name = ops::MatMul(node1, node2, b.opts())->name();
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
ops::Output m = ops::MatMul(root, test_tensor1, test_tensor2);
|
||||
const string result_name = m.node()->name();
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
string graph_def_serialized;
|
||||
graph_def.SerializeToString(&graph_def_serialized);
|
||||
TF_ASSERT_OK(
|
||||
|
@ -1449,6 +1449,7 @@ tf_cc_tests(
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
@ -1752,6 +1753,7 @@ tf_cc_test_gpu(
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
],
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -35,9 +34,9 @@ class GpuStreamUtilTest : public OpsTestBase {
|
||||
};
|
||||
|
||||
TEST_F(GpuStreamUtilTest, BogusOpts) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(b.ToGraph(&g));
|
||||
root.ToGraph(&g);
|
||||
std::unordered_map<int, int> node_to_stream_id;
|
||||
gpu_stream_util::AssignStreamsOpts opts;
|
||||
Status status;
|
||||
@ -55,9 +54,9 @@ TEST_F(GpuStreamUtilTest, BogusOpts) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, EmptyGraph) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(b.ToGraph(&g));
|
||||
root.ToGraph(&g);
|
||||
std::unordered_map<int, int> node_to_stream_id;
|
||||
gpu_stream_util::AssignStreamsOpts opts;
|
||||
TF_ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
|
||||
@ -65,11 +64,10 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
|
||||
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
ops::MatMul(root, {}, {});
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(b.ToGraph(&g));
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
|
||||
std::unordered_map<int, int> node_to_stream_id;
|
||||
gpu_stream_util::AssignStreamsOpts opts;
|
||||
@ -85,11 +83,10 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
|
||||
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
ops::MatMul(root, {}, {});
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(b.ToGraph(&g));
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
|
||||
std::unordered_map<int, int> node_to_stream_id;
|
||||
gpu_stream_util::AssignStreamsOpts opts;
|
||||
@ -107,14 +104,13 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
ops::_Recv(DT_FLOAT, "input", "/cpu:0", 0, "/gpu:0",
|
||||
b.opts().WithName("input"));
|
||||
auto n = ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
|
||||
ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
|
||||
ops::_Send(n, "output", "/gpu:0", 0, "/cpu:0", b.opts().WithName("output"));
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
|
||||
"/gpu:0");
|
||||
ops::Output n = ops::MatMul(root, {}, {});
|
||||
ops::_Send(root.WithOpName("output"), n, "output", "/gpu:0", 0, "/cpu:0");
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(b.ToGraph(&g));
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
|
||||
// Perform stream assignment using a large number of streams, but with
|
||||
// op types constrained to specific streams.
|
||||
|
@ -133,28 +133,41 @@ REGISTER_OP("Input").Output("o: float");
|
||||
REGISTER_OP("BoolInput").Output("o: bool");
|
||||
REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
|
||||
|
||||
Node* Input(const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("Input", opts);
|
||||
ops::Output ConstructOp(const Scope& scope, const string& op_type,
|
||||
const gtl::ArraySlice<ops::Input>& inputs) {
|
||||
if (!scope.ok()) return ops::Output();
|
||||
const string unique_name = scope.GetUniqueNameForOp(op_type);
|
||||
auto builder = NodeBuilder(unique_name, op_type);
|
||||
for (auto const& input : inputs) {
|
||||
builder.Input(ops::NodeOut(input.node(), input.index()));
|
||||
}
|
||||
scope.UpdateBuilder(&builder);
|
||||
Node* ret;
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
if (!scope.ok()) return ops::Output();
|
||||
return ops::Output(ret);
|
||||
}
|
||||
|
||||
Node* BoolInput(const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("BoolInput", opts);
|
||||
ops::Output Input(const Scope& scope) {
|
||||
return ConstructOp(scope, "Input", {});
|
||||
}
|
||||
|
||||
Node* Combine(ops::NodeOut a, ops::NodeOut b,
|
||||
const GraphDefBuilder::Options& opts) {
|
||||
return ops::BinaryOp("Combine", a, b, opts);
|
||||
ops::Output BoolInput(const Scope& scope) {
|
||||
return ConstructOp(scope, "BoolInput", {});
|
||||
}
|
||||
|
||||
ops::Output Combine(const Scope& scope, ops::Input a, ops::Input b) {
|
||||
return ConstructOp(scope, "Combine", {a, b});
|
||||
}
|
||||
|
||||
class GraphPartitionTest : public ::testing::Test {
|
||||
protected:
|
||||
GraphPartitionTest()
|
||||
: in_(GraphDefBuilder::kFailImmediately),
|
||||
builder_a_(GraphDefBuilder::kFailImmediately),
|
||||
builder_b_(GraphDefBuilder::kFailImmediately),
|
||||
a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")),
|
||||
b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) {
|
||||
}
|
||||
: in_(Scope::NewRootScope().ExitOnError()),
|
||||
scope_a_(Scope::NewRootScope().ExitOnError().WithDevice(
|
||||
"/job:a/replica:0/task:0/cpu:0")),
|
||||
scope_b_(Scope::NewRootScope().ExitOnError().WithDevice(
|
||||
"/job:a/replica:0/task:0/cpu:1")) {}
|
||||
|
||||
const GraphDef& ToGraphDef() {
|
||||
in_.ToGraphDef(&in_graph_def_);
|
||||
@ -163,187 +176,187 @@ class GraphPartitionTest : public ::testing::Test {
|
||||
|
||||
void ExpectMatchA() {
|
||||
GraphDef graph_def;
|
||||
builder_a_.ToGraphDef(&graph_def);
|
||||
scope_a_.ToGraphDef(&graph_def);
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
|
||||
}
|
||||
|
||||
void ExpectMatchB() {
|
||||
GraphDef graph_def;
|
||||
builder_b_.ToGraphDef(&graph_def);
|
||||
scope_b_.ToGraphDef(&graph_def);
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
|
||||
}
|
||||
|
||||
GraphDefBuilder in_;
|
||||
Scope in_;
|
||||
GraphDef in_graph_def_;
|
||||
GraphDefBuilder builder_a_;
|
||||
GraphDefBuilder builder_b_;
|
||||
GraphDefBuilder::Options a_opts_;
|
||||
GraphDefBuilder::Options b_opts_;
|
||||
Scope scope_a_;
|
||||
Scope scope_b_;
|
||||
std::unordered_map<string, GraphDef> partitions_;
|
||||
};
|
||||
|
||||
TEST_F(GraphPartitionTest, SingleDevice) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Combine(a1, a1, in_.opts().WithName("A2"));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
Combine(in_.WithOpName("A2"), a1, a1);
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(1, partitions_.size());
|
||||
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
Combine(a1, a1, a_opts_.WithName("A2"));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
Combine(scope_a_.WithOpName("A2"), a1, a1);
|
||||
ExpectMatchA();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceData) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Node* b1 = Input(in_.opts().WithName("B1"));
|
||||
Combine(a1, b1, in_.opts().WithName("B2"));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
auto b1 = Input(in_.WithOpName("B1"));
|
||||
Combine(in_.WithOpName("B2"), a1, b1);
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(2, partitions_.size());
|
||||
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
_Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
_Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
|
||||
ExpectMatchA();
|
||||
|
||||
b1 = Input(b_opts_.WithName("B1"));
|
||||
Node* recv =
|
||||
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
|
||||
Combine(recv, b1, b_opts_.WithName("B2"));
|
||||
b1 = Input(scope_b_.WithOpName("B1"));
|
||||
auto recv =
|
||||
_Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
|
||||
Combine(scope_b_.WithOpName("B2"), recv, b1);
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceControl) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Node* b1 = Input(in_.opts().WithName("B1"));
|
||||
Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
auto b1 = Input(in_.WithOpName("B1"));
|
||||
Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(2, partitions_.size());
|
||||
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
|
||||
_Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1"));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
|
||||
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
|
||||
ExpectMatchA();
|
||||
|
||||
Node* recv =
|
||||
_Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2"));
|
||||
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
|
||||
b1 = Input(b_opts_.WithName("B1"));
|
||||
Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
|
||||
auto recv =
|
||||
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
|
||||
auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
|
||||
b1 = Input(scope_b_.WithOpName("B1"));
|
||||
Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Node* b1 = Input(in_.opts().WithName("B1"));
|
||||
Combine(a1, b1, in_.opts().WithName("B2"));
|
||||
Combine(a1, a1, in_.opts().WithName("B3"));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
auto b1 = Input(in_.WithOpName("B1"));
|
||||
Combine(in_.WithOpName("B2"), a1, b1);
|
||||
Combine(in_.WithOpName("B3"), a1, a1);
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(2, partitions_.size());
|
||||
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
_Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
_Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
|
||||
ExpectMatchA();
|
||||
|
||||
Node* recv =
|
||||
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
|
||||
b1 = Input(b_opts_.WithName("B1"));
|
||||
Combine(recv, b1, b_opts_.WithName("B2"));
|
||||
Combine(recv, recv, b_opts_.WithName("B3"));
|
||||
auto recv =
|
||||
_Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
|
||||
b1 = Input(scope_b_.WithOpName("B1"));
|
||||
Combine(scope_b_.WithOpName("B2"), recv, b1);
|
||||
Combine(scope_b_.WithOpName("B3"), recv, recv);
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Node* b1 = Input(in_.opts().WithName("B1"));
|
||||
Combine(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
|
||||
Input(in_.opts().WithName("B3").WithControlInput(a1));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
auto b1 = Input(in_.WithOpName("B1"));
|
||||
Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
|
||||
Input(in_.WithOpName("B3").WithControlDependencies(a1));
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(2, partitions_.size());
|
||||
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
|
||||
_Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
|
||||
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
|
||||
ExpectMatchA();
|
||||
|
||||
Node* recv =
|
||||
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
|
||||
Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
|
||||
b1 = Input(b_opts_.WithName("B1"));
|
||||
Combine(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
|
||||
Input(b_opts_.WithName("B3").WithControlInput(id));
|
||||
auto recv =
|
||||
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
|
||||
auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
|
||||
b1 = Input(scope_b_.WithOpName("B1"));
|
||||
Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
|
||||
Input(scope_b_.WithOpName("B3").WithControlDependencies(id));
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = Input(in_.opts().WithName("A1"));
|
||||
Node* b1 = Input(in_.opts().WithName("B1"));
|
||||
Combine(a1, b1, in_.opts().WithName("B2"));
|
||||
Input(in_.opts().WithName("B3").WithControlInput(a1));
|
||||
auto a1 = Input(in_.WithOpName("A1"));
|
||||
auto b1 = Input(in_.WithOpName("B1"));
|
||||
Combine(in_.WithOpName("B2"), a1, b1);
|
||||
Input(in_.WithOpName("B3").WithControlDependencies(a1));
|
||||
|
||||
Partition(ToGraphDef(), &partitions_);
|
||||
EXPECT_EQ(2, partitions_.size());
|
||||
|
||||
string a = "/job:a/replica:0/task:0/cpu:0";
|
||||
string b = "/job:a/replica:0/task:0/cpu:1";
|
||||
a1 = Input(a_opts_.WithName("A1"));
|
||||
Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
|
||||
a1 = Input(scope_a_.WithOpName("A1"));
|
||||
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
|
||||
// NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could
|
||||
// use A1/_0 -> A1/_4 as the control as a minor optimization.
|
||||
_Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
|
||||
_Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4"));
|
||||
_Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
|
||||
_Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b);
|
||||
ExpectMatchA();
|
||||
|
||||
Node* recv1 =
|
||||
_Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
|
||||
Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3"));
|
||||
Node* recv2 =
|
||||
_Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5"));
|
||||
b1 = Input(b_opts_.WithName("B1"));
|
||||
Combine(recv2, b1, b_opts_.WithName("B2"));
|
||||
Input(b_opts_.WithName("B3").WithControlInput(id1));
|
||||
auto recv1 =
|
||||
_Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
|
||||
auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1);
|
||||
auto recv2 =
|
||||
_Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b);
|
||||
b1 = Input(scope_b_.WithOpName("B1"));
|
||||
Combine(scope_b_.WithOpName("B2"), recv2, b1);
|
||||
Input(scope_b_.WithOpName("B3").WithControlDependencies(id1));
|
||||
ExpectMatchB();
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoop) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = BoolInput(in_.opts().WithName("A1"));
|
||||
Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2"));
|
||||
Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
|
||||
LoopCond(a3, in_.opts().WithName("A4"));
|
||||
Node* b1 = Identity(a3, in_.opts().WithName("B1"));
|
||||
NextIteration(b1, in_.opts().WithName("A5"));
|
||||
auto a1 = BoolInput(in_.WithOpName("A1"));
|
||||
auto a2 = Enter(in_.WithOpName("A2"), a1, "foo");
|
||||
auto a3 =
|
||||
Merge(in_.WithOpName("A3"), {a2, ops::Input("A5", 0, DT_BOOL)}).output;
|
||||
LoopCond(in_.WithOpName("A4"), a3);
|
||||
auto b1 = Identity(in_.WithOpName("B1"), a3);
|
||||
NextIteration(in_.WithOpName("A5"), b1);
|
||||
|
||||
CheckLoopConstruction(ToGraphDef());
|
||||
}
|
||||
|
||||
TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
Node* a1 = BoolInput(in_.opts().WithName("A1"));
|
||||
Node* a2 = Enter(a1, "foo", in_.opts().WithName("B2"));
|
||||
Node* a3 = Merge({a2, {"B5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
|
||||
LoopCond(a3, in_.opts().WithName("A4"));
|
||||
Node* b1 = Identity(a3, in_.opts().WithName("B1"));
|
||||
NextIteration(b1, in_.opts().WithName("B5"));
|
||||
auto a1 = BoolInput(in_.WithOpName("A1"));
|
||||
auto a2 = Enter(in_.WithOpName("B2"), a1, "foo");
|
||||
auto a3 =
|
||||
Merge(in_.WithOpName("A3"), {a2, ops::Input("B5", 0, DT_BOOL)}).output;
|
||||
LoopCond(in_.WithOpName("A4"), a3);
|
||||
auto b1 = Identity(in_.WithOpName("B1"), a3);
|
||||
NextIteration(in_.WithOpName("B5"), b1);
|
||||
|
||||
std::unordered_map<string, GraphDef> partitions;
|
||||
Partition(ToGraphDef(), &partitions);
|
||||
|
@ -91,14 +91,14 @@ struct ImmutableConstantOpTest {};
|
||||
TEST(ImmutableConstantOpTest, Simple) {
|
||||
const TensorShape kTestTensorShape({4, 1});
|
||||
const TensorShape kTestTensorShapeT({1, 4});
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* node1 =
|
||||
ops::ImmutableConst(DT_FLOAT, kTestTensorShape, "test://2", b.opts());
|
||||
Node* node2 =
|
||||
ops::ImmutableConst(DT_FLOAT, kTestTensorShapeT, "test://3", b.opts());
|
||||
Node* result = ops::MatMul(node1, node2, b.opts());
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto node1 =
|
||||
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShape, "test://2");
|
||||
auto node2 =
|
||||
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShapeT, "test://3");
|
||||
auto result = ops::MatMul(root, node1, node2);
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
session_options.config.mutable_graph_options()
|
||||
@ -108,7 +108,7 @@ TEST(ImmutableConstantOpTest, Simple) {
|
||||
ASSERT_TRUE(session != nullptr) << "Failed to create session";
|
||||
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->Run({}, {result->name() + ":0"}, {}, &outputs));
|
||||
TF_ASSERT_OK(session->Run({}, {result.node()->name() + ":0"}, {}, &outputs));
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f);
|
||||
@ -122,14 +122,14 @@ TEST(ImmutableConstantOpTest, Simple) {
|
||||
TEST(ImmutableConstantOpTest, ExecutionError) {
|
||||
const TensorShape kBadTensorShape({40, 100});
|
||||
const TensorShape kTestTensorShapeT({1, 4});
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* node1 =
|
||||
ops::ImmutableConst(DT_FLOAT, kBadTensorShape, "test://2", b.opts());
|
||||
Node* node2 =
|
||||
ops::ImmutableConst(DT_FLOAT, kTestTensorShapeT, "test://3", b.opts());
|
||||
Node* result = ops::MatMul(node1, node2, b.opts());
|
||||
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test://2");
|
||||
auto node2 =
|
||||
ops::ImmutableConst(root, DT_FLOAT, kTestTensorShapeT, "test://3");
|
||||
auto result = ops::MatMul(root, node1, node2);
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
std::unique_ptr<Session> session(NewSession(session_options));
|
||||
@ -137,8 +137,9 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
|
||||
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
|
||||
std::vector<Tensor> outputs;
|
||||
// Check that the run returned error.
|
||||
EXPECT_EQ(session->Run({}, {result->name() + ":0"}, {}, &outputs).code(),
|
||||
error::INTERNAL);
|
||||
EXPECT_EQ(
|
||||
session->Run({}, {result.node()->name() + ":0"}, {}, &outputs).code(),
|
||||
error::INTERNAL);
|
||||
}
|
||||
|
||||
Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
|
||||
@ -158,19 +159,18 @@ Status CreateTempFile(Env* env, float value, uint64 size, string* filename) {
|
||||
TEST(ImmutableConstantOpTest, FromFile) {
|
||||
const TensorShape kFileTensorShape({1000, 1});
|
||||
Env* env = Env::Default();
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
string two_file, three_file;
|
||||
TF_ASSERT_OK(CreateTempFile(env, 2.0f, 1000, &two_file));
|
||||
TF_ASSERT_OK(CreateTempFile(env, 3.0f, 1000, &three_file));
|
||||
Node* node1 =
|
||||
ops::ImmutableConst(DT_FLOAT, kFileTensorShape, two_file, b.opts());
|
||||
Node* node2 =
|
||||
ops::ImmutableConst(DT_FLOAT, kFileTensorShape, three_file, b.opts());
|
||||
Node* result =
|
||||
ops::MatMul(node1, node2, b.opts().WithAttr("transpose_b", true));
|
||||
auto node1 = ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, two_file);
|
||||
auto node2 =
|
||||
ops::ImmutableConst(root, DT_FLOAT, kFileTensorShape, three_file);
|
||||
auto result = ops::MatMul(root, node1, node2, ops::MatMul::TransposeB(true));
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
SessionOptions session_options;
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
@ -179,7 +179,7 @@ TEST(ImmutableConstantOpTest, FromFile) {
|
||||
ASSERT_TRUE(session != nullptr) << "Failed to create session";
|
||||
TF_ASSERT_OK(session->Create(graph_def)) << "Can't create test graph";
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->Run({}, {result->name() + ":0"}, {}, &outputs));
|
||||
TF_ASSERT_OK(session->Run({}, {result.node()->name() + ":0"}, {}, &outputs));
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f);
|
||||
|
@ -1047,7 +1047,7 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
int depth, int kernel_rows, int kernel_cols,
|
||||
int stride, Padding padding, int num_threads,
|
||||
bool use_gpu, const string& label) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
int64 out_height, out_width, pad_rows, pad_cols;
|
||||
TF_CHECK_OK(GetWindowedOutputSize(rows, kernel_rows, stride, padding,
|
||||
@ -1057,27 +1057,25 @@ static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
|
||||
|
||||
Tensor input_data(DT_FLOAT, TensorShape({batch_size, rows, cols, depth}));
|
||||
input_data.flat<float>().setRandom();
|
||||
Node* input_data_node = ops::Const(input_data, b.opts());
|
||||
|
||||
Tensor output_data(DT_FLOAT,
|
||||
TensorShape({batch_size, out_height, out_width, depth}));
|
||||
output_data.flat<float>().setRandom();
|
||||
Node* output_data_node = ops::Const(output_data, b.opts());
|
||||
|
||||
Tensor output_diff(DT_FLOAT,
|
||||
TensorShape({batch_size, out_height, out_width, depth}));
|
||||
output_diff.flat<float>().setRandom();
|
||||
Node* output_diff_node = ops::Const(output_diff, b.opts());
|
||||
|
||||
CHECK_EQ(kernel_rows, kernel_cols);
|
||||
ops::MaxPoolGrad(input_data_node, output_data_node, output_diff_node,
|
||||
ops::MaxPoolGrad(root, input_data, output_data, output_diff,
|
||||
{1, kernel_rows, kernel_cols, 1} /* ksize */,
|
||||
{1, stride, stride, 1} /* stride */,
|
||||
padding == VALID ? "VALID" : "SAME", b.opts());
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
TF_CHECK_OK(b.ToGraph(g));
|
||||
padding == VALID ? "VALID" : "SAME");
|
||||
TF_CHECK_OK(root.status());
|
||||
Graph g(OpRegistry::Global());
|
||||
root.ToGraph(&g);
|
||||
string device = use_gpu ? "gpu" : "cpu";
|
||||
test::Benchmark(device, g).Run(iters);
|
||||
test::Benchmark(device, &g).Run(iters);
|
||||
|
||||
testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
|
||||
testing::SetLabel(label);
|
||||
|
@ -669,12 +669,9 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
|
||||
// Builds the graph.
|
||||
const string temp_filename =
|
||||
io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* filename = ops::Const(test::AsScalar<string>(temp_filename), b.opts());
|
||||
Node* tensor_names =
|
||||
ops::Const(test::AsTensor<string>({"my_tensor"}), b.opts());
|
||||
Node* tensors = ops::Const(tensor, b.opts());
|
||||
ops::Save(filename, tensor_names, {tensors}, b.opts());
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
const string tensor_name = "my_tensor";
|
||||
ops::Save(root, temp_filename, {tensor_name}, {{tensor}});
|
||||
|
||||
// Disables optimizations.
|
||||
SessionOptions session_options;
|
||||
@ -682,13 +679,14 @@ static void BM_LargeTensorWrite(int iters, int num_elements) {
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(tensorflow::OptimizerOptions_Level_L0);
|
||||
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
TF_CHECK_OK(b.ToGraph(g));
|
||||
TF_CHECK_OK(root.status());
|
||||
Graph g(OpRegistry::Global());
|
||||
root.ToGraph(&g);
|
||||
VLOG(1) << "Save op's output path: " << temp_filename;
|
||||
VLOG(1) << "# nodes in Graph: " << g->num_nodes();
|
||||
VLOG(1) << "# nodes in Graph: " << g.num_nodes();
|
||||
|
||||
testing::StartTiming();
|
||||
test::Benchmark("cpu", g, &session_options).Run(iters);
|
||||
test::Benchmark("cpu", &g, &session_options).Run(iters);
|
||||
}
|
||||
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
|
||||
|
||||
|
@ -87,50 +87,44 @@ Status ReadTensorFromImageFile(string file_name, const int input_height,
|
||||
const int input_width, const float input_mean,
|
||||
const float input_std,
|
||||
std::vector<Tensor>* out_tensors) {
|
||||
tensorflow::GraphDefBuilder b;
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
string input_name = "file_reader";
|
||||
string output_name = "normalized";
|
||||
tensorflow::Node* file_reader =
|
||||
tensorflow::ops::ReadFile(tensorflow::ops::Const(file_name, b.opts()),
|
||||
b.opts().WithName(input_name));
|
||||
auto file_reader = ReadFile(root.WithOpName(input_name), file_name);
|
||||
// Now try to figure out what kind of file it is and decode it.
|
||||
const int wanted_channels = 3;
|
||||
tensorflow::Node* image_reader;
|
||||
Output image_reader;
|
||||
if (tensorflow::StringPiece(file_name).ends_with(".png")) {
|
||||
image_reader = tensorflow::ops::DecodePng(
|
||||
file_reader,
|
||||
b.opts().WithAttr("channels", wanted_channels).WithName("png_reader"));
|
||||
image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
|
||||
DecodePng::Channels(wanted_channels));
|
||||
} else {
|
||||
// Assume if it's not a PNG then it must be a JPEG.
|
||||
image_reader = tensorflow::ops::DecodeJpeg(
|
||||
file_reader,
|
||||
b.opts().WithAttr("channels", wanted_channels).WithName("jpeg_reader"));
|
||||
image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
|
||||
DecodeJpeg::Channels(wanted_channels));
|
||||
}
|
||||
// Now cast the image data to float so we can do normal math on it.
|
||||
tensorflow::Node* float_caster = tensorflow::ops::Cast(
|
||||
image_reader, tensorflow::DT_FLOAT, b.opts().WithName("float_caster"));
|
||||
auto float_caster =
|
||||
Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
|
||||
// The convention for image ops in TensorFlow is that all images are expected
|
||||
// to be in batches, so that they're four-dimensional arrays with indices of
|
||||
// [batch, height, width, channel]. Because we only have a single image, we
|
||||
// have to add a batch dimension of 1 to the start with ExpandDims().
|
||||
tensorflow::Node* dims_expander = tensorflow::ops::ExpandDims(
|
||||
float_caster, tensorflow::ops::Const(0, b.opts()), b.opts());
|
||||
auto dims_expander = ExpandDims(root, float_caster, 0);
|
||||
// Bilinearly resize the image to fit the required dimensions.
|
||||
tensorflow::Node* resized = tensorflow::ops::ResizeBilinear(
|
||||
dims_expander, tensorflow::ops::Const({input_height, input_width},
|
||||
b.opts().WithName("size")),
|
||||
b.opts());
|
||||
auto resized = ResizeBilinear(
|
||||
root, dims_expander,
|
||||
Const(root.WithOpName("size"), {input_height, input_width}));
|
||||
// Subtract the mean and divide by the scale.
|
||||
tensorflow::ops::Div(
|
||||
tensorflow::ops::Sub(
|
||||
resized, tensorflow::ops::Const({input_mean}, b.opts()), b.opts()),
|
||||
tensorflow::ops::Const({input_std}, b.opts()),
|
||||
b.opts().WithName(output_name));
|
||||
Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
|
||||
{input_std});
|
||||
|
||||
// This runs the GraphDef network definition that we've just constructed, and
|
||||
// returns the results in the output tensor.
|
||||
tensorflow::GraphDef graph;
|
||||
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
|
||||
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
|
||||
|
||||
std::unique_ptr<tensorflow::Session> session(
|
||||
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||
TF_RETURN_IF_ERROR(session->Create(graph));
|
||||
@ -161,15 +155,16 @@ Status LoadGraph(string graph_file_name,
|
||||
// their positions in the tensor, which correspond to categories.
|
||||
Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
|
||||
Tensor* indices, Tensor* scores) {
|
||||
tensorflow::GraphDefBuilder b;
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
string output_name = "top_k";
|
||||
tensorflow::ops::TopKV2(tensorflow::ops::Const(outputs[0], b.opts()),
|
||||
tensorflow::ops::Const(how_many_labels, b.opts()),
|
||||
b.opts().WithName(output_name));
|
||||
TopKV2(root.WithOpName(output_name), outputs[0], how_many_labels);
|
||||
// This runs the GraphDef network definition that we've just constructed, and
|
||||
// returns the results in the output tensors.
|
||||
tensorflow::GraphDef graph;
|
||||
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph));
|
||||
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
|
||||
|
||||
std::unique_ptr<tensorflow::Session> session(
|
||||
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||
TF_RETURN_IF_ERROR(session->Create(graph));
|
||||
|
@ -175,6 +175,11 @@ def tf_gen_op_wrappers_cc(name,
|
||||
other_srcs=[],
|
||||
other_hdrs=[],
|
||||
pkg="",
|
||||
deps=[
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:const_op",
|
||||
],
|
||||
op_gen="//tensorflow/cc:cc_op_gen_main"):
|
||||
subsrcs = other_srcs
|
||||
subhdrs = other_hdrs
|
||||
@ -186,7 +191,12 @@ def tf_gen_op_wrappers_cc(name,
|
||||
native.cc_library(name=name,
|
||||
srcs=subsrcs,
|
||||
hdrs=subhdrs,
|
||||
deps=["//tensorflow/core:core_cpu"],
|
||||
deps=deps + [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
copts=tf_copts(),
|
||||
alwayslink=1,)
|
||||
|
||||
|
@ -39,16 +39,15 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
|
||||
Tensor constant_tensor(DT_FLOAT, constant_shape);
|
||||
test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
|
||||
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* placeholder =
|
||||
ops::Placeholder(DT_FLOAT, b.opts().WithAttr("shape", input_shape));
|
||||
const string input_name = placeholder->name();
|
||||
Node* constant = ops::Const(constant_tensor, b.opts());
|
||||
const string output_name =
|
||||
ops::MatMul(placeholder, constant, b.opts())->name();
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto placeholder =
|
||||
ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input_shape));
|
||||
const string input_name = placeholder.node()->name();
|
||||
auto m = ops::MatMul(root, placeholder, constant_tensor);
|
||||
const string output_name = m.node()->name();
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
string graph_def_serialized;
|
||||
graph_def.SerializeToString(&graph_def_serialized);
|
||||
TF_ASSERT_OK(
|
||||
|
Loading…
Reference in New Issue
Block a user