Make function instantiation use std::vector<NodeDef> instead of GraphDef
It's about to turn into std::vector<NodeInfoPtr>; this change gets us partway there. RELNOTES: n/a PiperOrigin-RevId: 157771141
This commit is contained in:
parent
2e44be35dc
commit
8032e1f75d
@ -1231,7 +1231,7 @@ Status FunctionDefToBodyHelper(
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = false;
|
||||
Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
|
||||
Status s = ConvertNodeDefsToGraph(opts, result.nodes, graph);
|
||||
if (!s.ok()) {
|
||||
delete graph;
|
||||
} else {
|
||||
|
@ -93,7 +93,7 @@ class FunctionTest : public ::testing::Test {
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = false;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g));
|
||||
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g));
|
||||
|
||||
const int version = g->versions().producer();
|
||||
LocalExecutorParams params;
|
||||
@ -949,7 +949,7 @@ GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
opts.expect_device_spec = false;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get()));
|
||||
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
|
||||
pass(g.get());
|
||||
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
|
||||
CopyGraph(*g, g1.get());
|
||||
|
@ -140,7 +140,7 @@ class FunctionInstantiationHelper {
|
||||
FunctionInstantiationHelper(GetFunctionSignature get_function,
|
||||
InstantiationResult* result)
|
||||
: get_function_(std ::move(get_function)), result_(*result) {
|
||||
result_.gdef.Clear();
|
||||
result_.nodes.clear();
|
||||
}
|
||||
|
||||
// Builds index for nodes that can be used as node's input arguments.
|
||||
@ -151,15 +151,14 @@ class FunctionInstantiationHelper {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
|
||||
CHECK_GE(dtypes.size(), size_t{1});
|
||||
GraphDef* gdef = &result_.gdef;
|
||||
int arg_index = gdef->node_size();
|
||||
int arg_index = result_.nodes.size();
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
|
||||
// Creates dtypes.size() nodes in the gdef.
|
||||
// Creates dtypes.size() nodes in the graph.
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
|
||||
{true, arg_index, 0, false, {dtypes[i]}}));
|
||||
DCHECK_EQ(arg_index, gdef->node_size());
|
||||
DCHECK_EQ(arg_index, result_.nodes.size());
|
||||
string name = arg_def.name();
|
||||
if (dtypes.size() > 1) {
|
||||
strings::StrAppend(&name, "_", i);
|
||||
@ -332,13 +331,13 @@ class FunctionInstantiationHelper {
|
||||
// Adds the actual node inputs to the result graph by converting indexes to
|
||||
// the node names.
|
||||
void AddNodeInputs() {
|
||||
for (int i = 0; i < result_.gdef.node_size(); i++) {
|
||||
for (int i = 0; i < result_.nodes.size(); i++) {
|
||||
NodeInfo& node_info = nodes_[i];
|
||||
for (const auto& p : node_info.data_inputs) {
|
||||
result_.gdef.mutable_node(i)->add_input(Name(p.first, p.second));
|
||||
result_.nodes[i].add_input(Name(p.first, p.second));
|
||||
}
|
||||
for (int index : node_info.control_inputs) {
|
||||
result_.gdef.mutable_node(i)->add_input(Dep(index));
|
||||
result_.nodes[i].add_input(Dep(index));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -348,11 +347,10 @@ class FunctionInstantiationHelper {
|
||||
// node's input arguments.
|
||||
//
|
||||
// If is_func_arg is true, the name is a function's argument. In
|
||||
// this case, the produced graph def has gdef.node[nid ... nid +
|
||||
// dtype.size()).
|
||||
// this case, the produced graph def has node[nid:nid + dtype.size()].
|
||||
//
|
||||
// Otherwise, the name is a function body's node return value. In
|
||||
// this case, the produced graph def has one node gdef.node[nid] and
|
||||
// this case, the produced graph def has one node node[nid] and
|
||||
// the node's output index [idx ... idx + num) corresponds to the
|
||||
// named outputs.
|
||||
//
|
||||
@ -398,10 +396,11 @@ class FunctionInstantiationHelper {
|
||||
}
|
||||
|
||||
NodeDef* AddNode(const string& name) {
|
||||
NodeDef* gnode = result_.gdef.add_node();
|
||||
result_.nodes.emplace_back();
|
||||
NodeDef* gnode = &result_.nodes.back();
|
||||
gnode->set_name(name);
|
||||
nodes_.push_back({name, {}, {}});
|
||||
CHECK_EQ(result_.gdef.node_size(), nodes_.size());
|
||||
CHECK_EQ(result_.nodes.size(), nodes_.size());
|
||||
return gnode;
|
||||
}
|
||||
|
||||
@ -429,7 +428,7 @@ class FunctionInstantiationHelper {
|
||||
// Control inputs (dependencies).
|
||||
std::vector<int> control_inputs;
|
||||
};
|
||||
// nodes_[i] is the information about result_.gdef.node(i).
|
||||
// nodes_[i] is the information about result_.nodes[i].
|
||||
std::vector<NodeInfo> nodes_;
|
||||
};
|
||||
|
||||
@ -545,17 +544,17 @@ string Print(const FunctionDef& fdef) {
|
||||
return out;
|
||||
}
|
||||
|
||||
string Print(const GraphDef& gdef) {
|
||||
string Print(gtl::ArraySlice<const NodeDef*> nodes) {
|
||||
std::vector<const NodeDef*> arg;
|
||||
std::vector<const NodeDef*> ret;
|
||||
std::vector<const NodeDef*> body;
|
||||
for (const NodeDef& n : gdef.node()) {
|
||||
if (n.op() == "_Arg") {
|
||||
arg.push_back(&n);
|
||||
} else if (n.op() == "_Retval") {
|
||||
ret.push_back(&n);
|
||||
for (const NodeDef* n : nodes) {
|
||||
if (n->op() == "_Arg") {
|
||||
arg.push_back(n);
|
||||
} else if (n->op() == "_Retval") {
|
||||
ret.push_back(n);
|
||||
} else {
|
||||
body.push_back(&n);
|
||||
body.push_back(n);
|
||||
}
|
||||
}
|
||||
auto comp = [](const NodeDef* x, const NodeDef* y) {
|
||||
@ -570,12 +569,11 @@ string Print(const GraphDef& gdef) {
|
||||
string out;
|
||||
strings::StrAppend(&out, "\n(");
|
||||
auto get_type = [](const NodeDef& n) {
|
||||
for (auto a : n.attr()) {
|
||||
if (a.first == "T") {
|
||||
return DataTypeString(a.second.type());
|
||||
}
|
||||
DataType dt;
|
||||
if (!GetNodeAttr(n, "T", &dt).ok()) {
|
||||
dt = DT_INVALID;
|
||||
}
|
||||
return DataTypeString(DT_INVALID);
|
||||
return DataTypeString(dt);
|
||||
};
|
||||
for (size_t i = 0; i < arg.size(); ++i) {
|
||||
const NodeDef* n = arg[i];
|
||||
@ -663,13 +661,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
|
||||
result->gdef.node_size() + i);
|
||||
result->nodes.size() + i);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
}
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node_def.
|
||||
// Emits one node for each fdef.node_def.
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
|
||||
if (!s.ok()) {
|
||||
@ -697,7 +695,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
string DebugString(const FunctionDef& func_def) { return Print(func_def); }
|
||||
|
||||
string DebugString(const GraphDef& instantiated_func_def) {
|
||||
return Print(instantiated_func_def);
|
||||
std::vector<const NodeDef*> ptrs;
|
||||
for (const NodeDef& n : instantiated_func_def.node()) {
|
||||
ptrs.push_back(&n);
|
||||
}
|
||||
return Print(ptrs);
|
||||
}
|
||||
|
||||
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
|
||||
std::vector<const NodeDef*> ptrs;
|
||||
for (const NodeDef& n : instantiated_func_nodes) {
|
||||
ptrs.push_back(&n);
|
||||
}
|
||||
return Print(ptrs);
|
||||
}
|
||||
|
||||
string DebugStringWhole(const GraphDef& gdef) {
|
||||
|
@ -200,7 +200,7 @@ typedef std::function<Status(const string&, const OpDef**)>
|
||||
struct InstantiationResult {
|
||||
DataTypeVector arg_types;
|
||||
DataTypeVector ret_types;
|
||||
GraphDef gdef;
|
||||
std::vector<NodeDef> nodes;
|
||||
};
|
||||
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
GetFunctionSignature get_function,
|
||||
@ -216,6 +216,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
// etc.)
|
||||
string DebugString(const FunctionDef& func_def);
|
||||
string DebugString(const GraphDef& instantiated_func_def);
|
||||
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes);
|
||||
|
||||
// Returns a debug string for a top level graph (the main program and
|
||||
// its supporting functions defined in its library).
|
||||
|
@ -108,7 +108,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, ControlDep) {
|
||||
@ -154,7 +154,7 @@ ControlDep(x:int32) -> (y:int32) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_INT32}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_INT32}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
REGISTER_OP("HasDefaultType")
|
||||
@ -198,7 +198,7 @@ BackCompat() -> (y:float) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector());
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, NTimesT) {
|
||||
@ -234,7 +234,7 @@ NTimesT(x:float, y:float) -> (z:float) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
// NOTE: This is the simplest Map op. It takes a f:T->U.
|
||||
@ -299,7 +299,7 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, ControlDeps) {
|
||||
@ -344,7 +344,7 @@ ControlDeps(x:float) -> () {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, XTimesTwo) {
|
||||
@ -425,7 +425,7 @@ Test(i:float) -> (o:float) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
REGISTER_OP("Cond")
|
||||
@ -493,7 +493,7 @@ MySelect(x:float) -> (z:float) {
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
EXPECT_EQ(DebugString(result.nodes), e2);
|
||||
}
|
||||
|
||||
static void HasError(const Status& s, const string& substr) {
|
||||
@ -1028,7 +1028,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) {
|
||||
*proto.add_gradient() = grad;
|
||||
FunctionLibraryDefinition lib_def3(OpRegistry::Global(), proto);
|
||||
TF_EXPECT_OK(lib_def.AddLibrary(lib_def3));
|
||||
};
|
||||
}
|
||||
|
||||
TEST(FunctionLibraryDefinitionTest, ToProto) {
|
||||
FunctionDefLibrary proto1;
|
||||
|
@ -39,6 +39,14 @@ string SummarizeGraphDef(const GraphDef& graph_def) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs) {
|
||||
string ret;
|
||||
for (const NodeDef& node : node_defs) {
|
||||
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
|
||||
|
@ -27,6 +27,7 @@ namespace tensorflow {
|
||||
// Produce a human-readable version of a GraphDef that is more concise
|
||||
// than a text-format proto.
|
||||
string SummarizeGraphDef(const GraphDef& graph_def);
|
||||
string SummarizeGraphDef(gtl::ArraySlice<NodeDef> node_defs);
|
||||
|
||||
// Validates the syntax of a GraphDef provided externally.
|
||||
//
|
||||
|
@ -91,24 +91,36 @@ class GraphConstructor {
|
||||
bool importing;
|
||||
};
|
||||
|
||||
static Status Construct(const Options& opts, const GraphDef* gdef, Graph* g,
|
||||
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
|
||||
|
||||
// versions and library may be nullptr
|
||||
static Status Construct(const Options& opts, NodeDefSlice node_defs,
|
||||
const VersionDef* versions,
|
||||
const FunctionDefLibrary* library, Graph* g,
|
||||
ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors) {
|
||||
TF_RETURN_IF_ERROR(CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION,
|
||||
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
|
||||
"GraphDef", "graph"));
|
||||
GraphConstructor c(opts, gdef, g, refiner, return_tensors);
|
||||
if (versions) {
|
||||
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
|
||||
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
|
||||
"GraphDef", "graph"));
|
||||
}
|
||||
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
|
||||
return_tensors);
|
||||
const Status s = c.TryImport();
|
||||
if (!s.ok()) c.Undo();
|
||||
return s;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphConstructor(const Options& opts, const GraphDef* gdef, Graph* g,
|
||||
GraphConstructor(const Options& opts, NodeDefSlice node_defs,
|
||||
const VersionDef* versions,
|
||||
const FunctionDefLibrary* library, Graph* g,
|
||||
ShapeRefiner* refiner,
|
||||
std::vector<std::pair<Node*, int>>* return_tensors)
|
||||
: opts_(opts),
|
||||
gdef_(gdef),
|
||||
node_defs_(node_defs),
|
||||
versions_(versions),
|
||||
library_(library),
|
||||
g_(g),
|
||||
original_versions_(g->versions()),
|
||||
refiner_(refiner),
|
||||
@ -159,7 +171,9 @@ class GraphConstructor {
|
||||
|
||||
// From constructor
|
||||
const Options opts_;
|
||||
const GraphDef* gdef_;
|
||||
const NodeDefSlice node_defs_;
|
||||
const VersionDef* versions_;
|
||||
const FunctionDefLibrary* library_;
|
||||
Graph* g_;
|
||||
const VersionDef original_versions_;
|
||||
|
||||
@ -168,7 +182,7 @@ class GraphConstructor {
|
||||
// May be null. Not owned.
|
||||
std::vector<std::pair<Node*, int>>* return_tensors_;
|
||||
|
||||
// Mapping from node name to the index within gdef_
|
||||
// Mapping from node name to the index within node_defs_
|
||||
struct NodeInfo {
|
||||
explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
|
||||
// std::unordered_map<> requires that we have a default constructor.
|
||||
@ -183,18 +197,18 @@ class GraphConstructor {
|
||||
// Mapping from node name to the existing node in g_
|
||||
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> existing_nodes_;
|
||||
|
||||
// Index of NodeDefs in gdef_ with all inputs already converted.
|
||||
// Index of NodeDefs in node_defs_ with all inputs already converted.
|
||||
std::vector<int> ready_;
|
||||
|
||||
// Mapping between index within gdef_ and the number of inputs that
|
||||
// Mapping between index within node_defs_ and the number of inputs that
|
||||
// still need to be converted.
|
||||
std::vector<int> pending_count_;
|
||||
|
||||
// Mapping between index within gdef_ and the index within gdef_ of
|
||||
// Mapping between index within node_defs_ and the index within node_defs_ of
|
||||
// all nodes it outputs to.
|
||||
std::vector<gtl::InlinedVector<int, 4>> outputs_;
|
||||
|
||||
// Used in the conversion from gdef_ to g_ to represent the ith input
|
||||
// Used in the conversion from node_defs_ to g_ to represent the ith input
|
||||
// of a node.
|
||||
struct InputInfo {
|
||||
explicit InputInfo(const string& node_name, Node* n, int i)
|
||||
@ -205,7 +219,7 @@ class GraphConstructor {
|
||||
int index;
|
||||
};
|
||||
|
||||
// Used in the conversion from gdef_ to g_ to represent an edge from
|
||||
// Used in the conversion from node_defs_ to g_ to represent an edge from
|
||||
// the node named 'name' to node 'n'.
|
||||
struct EdgeInfo {
|
||||
explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
|
||||
@ -254,8 +268,8 @@ Status GraphConstructor::EnsureNoNameCollisions() {
|
||||
}
|
||||
}
|
||||
if (opts_.prefix.empty() && opts_.importing) {
|
||||
for (int n = 0; n < gdef_->node_size(); ++n) {
|
||||
const string& name = gdef_->node(n).name();
|
||||
for (const NodeDef* n : node_defs_) {
|
||||
const string& name = n->name();
|
||||
if (existing_nodes_.find(name) != existing_nodes_.end()) {
|
||||
return errors::InvalidArgument("Node '", name,
|
||||
"' already exists in the Graph");
|
||||
@ -312,8 +326,8 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() {
|
||||
|
||||
Status GraphConstructor::BuildNodeIndex() {
|
||||
// Validate the node names and add them to gdef_nodes_.
|
||||
for (int n = 0; n < gdef_->node_size(); ++n) {
|
||||
const NodeDef& node_def(gdef_->node(n));
|
||||
for (int n = 0; n < node_defs_.size(); ++n) {
|
||||
const NodeDef& node_def = *node_defs_[n];
|
||||
if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
|
||||
return errors::InvalidArgument(
|
||||
"Node '", node_def.name(),
|
||||
@ -351,13 +365,13 @@ Status GraphConstructor::BuildNodeIndex() {
|
||||
}
|
||||
|
||||
Status GraphConstructor::InitFromEdges() {
|
||||
const int num_nodes = gdef_->node_size();
|
||||
const int num_nodes = node_defs_.size();
|
||||
pending_count_.reserve(num_nodes);
|
||||
outputs_.resize(num_nodes);
|
||||
|
||||
// Parse the inputs for each node.
|
||||
for (int n = 0; n < num_nodes; ++n) {
|
||||
const NodeDef& node_def(gdef_->node(n));
|
||||
const NodeDef& node_def = *node_defs_[n];
|
||||
if (IsMerge(node_def)) {
|
||||
// for merge only wait for one non-control input.
|
||||
int32 num_control_edges = 0;
|
||||
@ -489,7 +503,9 @@ Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
|
||||
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
|
||||
AddDefaultsToNodeDef(*op_def, node_def);
|
||||
TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
|
||||
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, gdef_->versions().producer()));
|
||||
if (versions_) {
|
||||
TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -608,7 +624,9 @@ void GraphConstructor::AddPrefixToNodeDef(
|
||||
Status GraphConstructor::Convert() {
|
||||
// Import functions before adding nodes, since imported nodes may refer to
|
||||
// functions
|
||||
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library()));
|
||||
if (library_) {
|
||||
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
|
||||
}
|
||||
|
||||
std::vector<InputInfo> inputs;
|
||||
int processed = 0;
|
||||
@ -626,14 +644,14 @@ Status GraphConstructor::Convert() {
|
||||
inputs.clear();
|
||||
bool has_data_back_edge = false;
|
||||
|
||||
const NodeDef& original_node_def = gdef_->node(o);
|
||||
const NodeDef& original_node_def = *node_defs_[o];
|
||||
NodeDef imported_node_def;
|
||||
const NodeDef* node_def;
|
||||
|
||||
// input_already_exists[i] is true iff the i-th input of the node we're
|
||||
// importing refers to a preexisting node in g_ (i.e. input[i] existed prior
|
||||
// to importing gdef_). Conversely, input_already_exists[i] is false iff
|
||||
// the input refers to a node in gdef_.
|
||||
// to importing node_defs_). Conversely, input_already_exists[i] is false
|
||||
// iff the input refers to a node in node_defs_.
|
||||
input_already_exists.clear();
|
||||
input_already_exists.resize(original_node_def.input_size(), false);
|
||||
|
||||
@ -731,8 +749,8 @@ Status GraphConstructor::Convert() {
|
||||
}
|
||||
}
|
||||
|
||||
if (processed < gdef_->node_size()) {
|
||||
return errors::InvalidArgument(gdef_->node_size() - processed,
|
||||
if (processed < node_defs_.size()) {
|
||||
return errors::InvalidArgument(node_defs_.size() - processed,
|
||||
" nodes in a cycle");
|
||||
}
|
||||
return Status::OK();
|
||||
@ -756,20 +774,21 @@ Status GraphConstructor::AddBackEdges() {
|
||||
}
|
||||
|
||||
Status GraphConstructor::UpdateVersionDef() {
|
||||
if (versions_ == nullptr) return Status::OK();
|
||||
|
||||
if (!opts_.importing) {
|
||||
g_->set_versions(gdef_->versions());
|
||||
g_->set_versions(*versions_);
|
||||
return Status::OK();
|
||||
}
|
||||
VersionDef versions = g_->versions();
|
||||
versions.set_producer(
|
||||
std::min(versions.producer(), gdef_->versions().producer()));
|
||||
versions.set_producer(std::min(versions.producer(), versions_->producer()));
|
||||
versions.set_min_consumer(
|
||||
std::max(versions.min_consumer(), gdef_->versions().min_consumer()));
|
||||
if (gdef_->versions().bad_consumers_size() > 0) {
|
||||
std::max(versions.min_consumer(), versions_->min_consumer()));
|
||||
if (versions_->bad_consumers_size() > 0) {
|
||||
std::set<int> bad(versions.bad_consumers().begin(),
|
||||
versions.bad_consumers().end());
|
||||
bad.insert(gdef_->versions().bad_consumers().begin(),
|
||||
gdef_->versions().bad_consumers().end());
|
||||
bad.insert(versions_->bad_consumers().begin(),
|
||||
versions_->bad_consumers().end());
|
||||
versions.clear_bad_consumers();
|
||||
for (int v : bad) {
|
||||
versions.add_bad_consumers(v);
|
||||
@ -837,7 +856,20 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
|
||||
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
|
||||
const GraphDef& gdef, Graph* g) {
|
||||
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
|
||||
return GraphConstructor::Construct(opts, &gdef, g, &refiner, nullptr);
|
||||
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
|
||||
&gdef.library(), g, &refiner, nullptr);
|
||||
}
|
||||
|
||||
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
||||
gtl::ArraySlice<NodeDef> nodes, Graph* g) {
|
||||
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
|
||||
// TODO(irving): Copy will go away once NodeInfo exists
|
||||
std::vector<const NodeDef*> node_defs;
|
||||
for (const auto& n : nodes) {
|
||||
node_defs.push_back(&n);
|
||||
}
|
||||
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
|
||||
&refiner, nullptr);
|
||||
}
|
||||
|
||||
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
|
||||
@ -886,7 +918,9 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
|
||||
refiner->set_graph_def_version(
|
||||
std::min(refiner->graph_def_version(), gdef.versions().producer()));
|
||||
|
||||
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors);
|
||||
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
|
||||
&gdef.library(), g, refiner,
|
||||
return_tensors);
|
||||
}
|
||||
|
||||
void CopyGraph(const Graph& src, Graph* dest) {
|
||||
|
@ -46,6 +46,12 @@ struct GraphConstructorOptions {
|
||||
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
|
||||
const GraphDef& gdef, Graph* g);
|
||||
|
||||
// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
|
||||
// instantiation.
|
||||
// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
|
||||
extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
||||
gtl::ArraySlice<NodeDef> nodes, Graph* g);
|
||||
|
||||
// Add the graph in GraphDef gdef into an existing Graph *g.
|
||||
//
|
||||
// On error, returns non-OK and leaves *g unmodified.
|
||||
|
@ -24,16 +24,10 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
// Intentionally do not check that versions match so that this routine can
|
||||
// be used for less brittle golden file tests.
|
||||
return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
|
||||
}
|
||||
|
||||
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
|
||||
const protobuf::RepeatedPtrField<NodeDef>& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
template <class NodeDefs>
|
||||
static bool EqualNodeDefsHelper(
|
||||
const NodeDefs& actual, const protobuf::RepeatedPtrField<NodeDef>& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
std::unordered_map<string, const NodeDef*> actual_index;
|
||||
for (const NodeDef& node : actual) {
|
||||
actual_index[node.name()] = &node;
|
||||
@ -68,6 +62,24 @@ bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
// Intentionally do not check that versions match so that this routine can
|
||||
// be used for less brittle golden file tests.
|
||||
return EqualNodeDefsHelper(actual.node(), expected.node(), diff, options);
|
||||
}
|
||||
|
||||
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
return EqualNodeDefsHelper(actual, expected.node(), diff, options);
|
||||
}
|
||||
|
||||
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
|
||||
const protobuf::RepeatedPtrField<NodeDef>& expected,
|
||||
string* diff, const EqualGraphDefOptions& options) {
|
||||
return EqualNodeDefsHelper(actual, expected, diff, options);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
|
||||
|
@ -36,6 +36,8 @@ struct EqualGraphDefOptions {
|
||||
// nodes must be consistent.
|
||||
bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
|
||||
string* diff, const EqualGraphDefOptions& options = {});
|
||||
bool EqualGraphDef(gtl::ArraySlice<NodeDef> actual, const GraphDef& expected,
|
||||
string* diff, const EqualGraphDefOptions& options = {});
|
||||
|
||||
// Determines if actual and expected are equal, ignoring: ordering of
|
||||
// attrs, internal attributes (if set in `options`), and control inputs.
|
||||
|
Loading…
Reference in New Issue
Block a user