[Grappler] Add a simple graph structure verifier.
Currently, the checks include simple structure verification like no duplicate node names, no illegal loops and node def and op registry matches. We will continue to add more checks. PiperOrigin-RevId: 230784105
This commit is contained in:
parent
d3c849152b
commit
c91dd82d8a
tensorflow/core
@ -2831,6 +2831,7 @@ tf_cuda_library(
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -113,5 +115,16 @@ Status ValidateGraphHasNoCycle(const Graph& graph) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VerifyNoDuplicateNodeNames(const GraphDef& graph) {
|
||||
absl::flat_hash_set<absl::string_view> nodes;
|
||||
for (const auto& node : graph.node()) {
|
||||
if (nodes.contains(node.name())) {
|
||||
return errors::AlreadyExists("Node already exists: ", node.name());
|
||||
}
|
||||
nodes.insert(node.name());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace graph
|
||||
} // namespace tensorflow
|
||||
|
@ -59,6 +59,9 @@ void GetOpListForValidation(
|
||||
// be less than the total node count.
|
||||
Status ValidateGraphHasNoCycle(const Graph& graph);
|
||||
|
||||
// Returns OK if the graph has no duplicate node names.
|
||||
Status VerifyNoDuplicateNodeNames(const GraphDef& graph);
|
||||
|
||||
} // namespace graph
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -147,5 +147,36 @@ TEST(GetOpListForValidationTest, ShouldStripDocs) {
|
||||
EXPECT_TRUE(found_has_docs);
|
||||
}
|
||||
|
||||
TEST(VerifyNoDuplicateNodeNames, NoDuplicateNodeNames) {
|
||||
const string graph_def_str =
|
||||
"node { name: 'A' op: 'FloatInput' }"
|
||||
"node { name: 'B' op: 'Int32Input' }"
|
||||
"node { "
|
||||
" name: 'C' op: 'Sum' "
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] "
|
||||
"}";
|
||||
GraphDef graph_def;
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
|
||||
TF_ASSERT_OK(graph::VerifyNoDuplicateNodeNames(graph_def));
|
||||
}
|
||||
|
||||
TEST(VerifyNoDuplicateNodeNames, DuplicateNodeNames) {
|
||||
const string graph_def_str =
|
||||
"node { name: 'A' op: 'FloatInput' }"
|
||||
"node { name: 'A' op: 'Int32Input' }"
|
||||
"node { "
|
||||
" name: 'C' op: 'Sum' "
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'A'] "
|
||||
"}";
|
||||
GraphDef graph_def;
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str;
|
||||
EXPECT_EQ(tensorflow::error::ALREADY_EXISTS,
|
||||
graph::VerifyNoDuplicateNodeNames(graph_def).code());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,5 +1,7 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "graph_verifier",
|
||||
hdrs = [
|
||||
@ -7,7 +9,43 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "structure_verifier",
|
||||
srcs = ["structure_verifier.cc"],
|
||||
hdrs = [
|
||||
"structure_verifier.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_verifier",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "structure_verifier_test",
|
||||
srcs = ["structure_verifier_test.cc"],
|
||||
deps = [
|
||||
":structure_verifier",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -44,7 +44,9 @@ class GraphVerifier {
|
||||
virtual string name() const = 0;
|
||||
|
||||
// Implement an algorithm to verify the specified graph.
|
||||
virtual Status Verify(const GraphDef& graph, std::vector<string>* errors) = 0;
|
||||
// The return value is a Status that represents a concatenation of Status of
|
||||
// each verification step.
|
||||
virtual Status Verify(const GraphDef& graph) = 0;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
45
tensorflow/core/grappler/verifiers/structure_verifier.cc
Normal file
45
tensorflow/core/grappler/verifiers/structure_verifier.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/grappler/verifiers/graph_verifier.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// TODO(ashwinm): Expand this to add more structural checks.
|
||||
Status StructureVerifier::Verify(const GraphDef& graph) {
|
||||
StatusGroup status_group;
|
||||
status_group.Update(tensorflow::graph::ValidateGraphDefAgainstOpRegistry(
|
||||
graph, *OpRegistry::Global()));
|
||||
status_group.Update(tensorflow::graph::VerifyNoDuplicateNodeNames(graph));
|
||||
|
||||
std::vector<const NodeDef*> topo_order;
|
||||
status_group.Update(ComputeTopologicalOrder(graph, &topo_order));
|
||||
return status_group.as_status();
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
43
tensorflow/core/grappler/verifiers/structure_verifier.h
Normal file
43
tensorflow/core/grappler/verifiers/structure_verifier.h
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/verifiers/graph_verifier.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Verifies the structure of a graph to ensure it is valid.
|
||||
class StructureVerifier : public GraphVerifier {
|
||||
public:
|
||||
StructureVerifier() {}
|
||||
~StructureVerifier() override {}
|
||||
|
||||
string name() const override { return "structure_verifier"; };
|
||||
|
||||
Status Verify(const GraphDef& graph) override;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_
|
116
tensorflow/core/grappler/verifiers/structure_verifier_test.cc
Normal file
116
tensorflow/core/grappler/verifiers/structure_verifier_test.cc
Normal file
@ -0,0 +1,116 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/cc/ops/parsing_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class StructureVerifierTest : public ::testing::Test {
|
||||
protected:
|
||||
StructureVerifierTest() { verifier_.reset(new StructureVerifier()); }
|
||||
void SetGraph(const string& gdef_ascii) {
|
||||
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &graph_));
|
||||
}
|
||||
GraphDef graph_;
|
||||
std::unique_ptr<StructureVerifier> verifier_;
|
||||
};
|
||||
|
||||
Status Scalars(shape_inference::InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->Scalar());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_OP("TestParams").Output("o: float").SetShapeFn(Scalars);
|
||||
REGISTER_OP("TestInput")
|
||||
.Output("a: float")
|
||||
.Output("b: float")
|
||||
.SetShapeFn(Scalars);
|
||||
REGISTER_OP("TestMul")
|
||||
.Input("a: float")
|
||||
.Input("b: float")
|
||||
.Output("o: float")
|
||||
.SetShapeFn(Scalars);
|
||||
|
||||
TEST_F(StructureVerifierTest, ValidGraphs) {
|
||||
// With scope, ops gets registered automatically.
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
|
||||
|
||||
GraphDef graph;
|
||||
TF_CHECK_OK(s.ToGraphDef(&graph));
|
||||
TF_EXPECT_OK(verifier_->Verify(graph));
|
||||
|
||||
// With graphdef directly, relies on REGISTER_OP to register ops
|
||||
SetGraph(
|
||||
"node { name: 'W1' op: 'TestParams' }"
|
||||
"node { name: 'input' op: 'TestInput' }"
|
||||
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }");
|
||||
|
||||
TF_EXPECT_OK(verifier_->Verify(graph_));
|
||||
}
|
||||
|
||||
TEST_F(StructureVerifierTest, OpNotRegistered) {
|
||||
SetGraph(
|
||||
"node { name: 'input' op: 'OpNotRegistered' }"
|
||||
"node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }"
|
||||
"node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }");
|
||||
Status status = verifier_->Verify(graph_);
|
||||
EXPECT_EQ(errors::Code::NOT_FOUND, status.code());
|
||||
EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
|
||||
"Op type not registered"));
|
||||
}
|
||||
|
||||
TEST_F(StructureVerifierTest, DuplicateNodeNames) {
|
||||
SetGraph(
|
||||
"node { name: 'A' op: 'TestParams' }"
|
||||
"node { name: 'A' op: 'TestInput' }");
|
||||
Status status = verifier_->Verify(graph_);
|
||||
EXPECT_EQ(errors::Code::ALREADY_EXISTS, status.code());
|
||||
EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
|
||||
"Node already exists:"));
|
||||
}
|
||||
|
||||
TEST_F(StructureVerifierTest, GraphWithInvalidCycle) {
|
||||
SetGraph(
|
||||
"node { name: 'input' op: 'TestInput' }"
|
||||
"node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }"
|
||||
"node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }");
|
||||
Status status = verifier_->Verify(graph_);
|
||||
EXPECT_EQ(errors::Code::INVALID_ARGUMENT, status.code());
|
||||
EXPECT_TRUE(tensorflow::str_util::StrContains(
|
||||
status.error_message(),
|
||||
"The graph couldn't be sorted in topological order"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user