Cleanup of transitive-fanin node computation to make details of invalid input available to its callers.

PiperOrigin-RevId: 307497151
Change-Id: I0e59172a24918b5de95832147d07e1e2fe6c0fe6
This commit is contained in:
A. Unique TensorFlower 2020-04-20 16:15:50 -07:00 committed by TensorFlower Gardener
parent 819330a213
commit 8f439272ed
11 changed files with 65 additions and 84 deletions

View File

@ -270,6 +270,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:utils",
"//tensorflow/core/grappler/utils:transitive_fanin",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
@ -405,14 +406,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
}
// Get the nodes that would run to output fetch_nodes.
bool ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_node;
const std::vector<const NodeDef*> fetch_fanin_nodes =
ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node, &ill_formed);
if (ill_formed) {
return errors::InvalidArgument(
"Ill formed graph or invalid set of fetch nodes specified");
}
std::vector<const NodeDef*> fetch_fanin_nodes;
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
&fetch_fanin_nodes));
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
// from the fetch nodes are scheduled. So the scheduled nodes should be

View File

@ -44,7 +44,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:transitive_fanin",
"@com_google_absl//absl/strings",
],
)

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@ -54,9 +54,9 @@ void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph,
if (fetch_nodes.empty()) {
*graph = metagraph.graph_def();
} else {
std::vector<const tensorflow::NodeDef*> fanin_nodes =
tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(),
fetch_nodes);
std::vector<const tensorflow::NodeDef*> fanin_nodes;
TF_CHECK_OK(tensorflow::grappler::ComputeTransitiveFanin(
metagraph.graph_def(), fetch_nodes, &fanin_nodes));
for (const tensorflow::NodeDef* node : fanin_nodes) {
*(graph->add_node()) = *node;
}

View File

@ -50,7 +50,9 @@ GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
}
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
return ComputeTransitiveFanin(graph, fetch);
std::vector<const NodeDef*> fanin_nodes;
TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
return fanin_nodes;
}
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
@ -60,15 +62,20 @@ std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
enqueue_ops.push_back(enqueue_op);
}
}
return ComputeTransitiveFanin(graph, enqueue_ops);
std::vector<const NodeDef*> fanin_nodes;
TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
return fanin_nodes;
}
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
return ComputeTransitiveFanin(graph, init_ops);
std::vector<const NodeDef*> fanin_nodes;
TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes));
return fanin_nodes;
}
std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
std::vector<const NodeDef*> fanin;
TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin));
std::vector<const NodeDef*> vars;
for (const NodeDef* node : fanin) {
if (IsVariable(*node)) {
@ -200,22 +207,5 @@ GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
return optimization_options_;
}
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes) {
bool ill_formed = false;
std::vector<const NodeDef*> result =
ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
CHECK(!ill_formed);
return result;
}
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
bool* ill_formed) {
std::unordered_map<string, const NodeDef*> name_to_fanin_node;
return ComputeTransitiveFanin(graph, terminal_nodes, &name_to_fanin_node,
ill_formed);
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -133,25 +133,6 @@ struct GrapplerItem {
OptimizationOptions optimization_options_;
};
// Return the transitive fanin of a set of terminal nodes.
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes);
// Return the transitive fanin of a set of terminal nodes. Sets 'ill_formed' to
// true if one of the node is missing in the graph, or some node inputs don't
// exist.
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
bool* ill_formed);
// Return the transitive fanin of a set of terminal nodes. Sets 'ill_formed' to
// true if one of the node is missing in the graph, or some node inputs don't
// exist. Sets name_to_fanin_node for name to fanin nodes map.
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
bool* ill_formed);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -75,6 +75,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/utils:transitive_fanin",
],
)

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@ -147,7 +148,8 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
}
LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
std::vector<const NodeDef*> train_nodes;
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
LOG(INFO) << "Number of training nodes: " << train_nodes.size();
const NodeDef* dequeue_node;
@ -161,7 +163,8 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
std::vector<const NodeDef*> input_nodes;
if (dequeue_node) {
LOG(INFO) << "Dequeue node: " << dequeue_node->name();
input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, {dequeue_node->name()},
{}, &input_nodes));
}
LOG(INFO) << "Number of input nodes: " << input_nodes.size();

View File

@ -20,15 +20,15 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace grappler {
std::vector<const NodeDef*> ComputeTransitiveFanin(
Status ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
bool* ill_formed) {
*ill_formed = false;
std::vector<const NodeDef*>* fanin_nodes) {
std::unordered_map<string, const NodeDef*> name_to_node;
std::unordered_map<string, const NodeDef*> name_to_send;
for (const auto& node : graph.node()) {
@ -43,14 +43,12 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
for (const string& root : terminal_nodes) {
const NodeDef* node = name_to_node[NodeName(root)];
if (!node) {
*ill_formed = true;
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
return {};
return errors::InvalidArgument("Graph does not contain terminal node ",
root, ".");
}
queue.push_back(node);
}
std::vector<const NodeDef*> result;
std::unordered_set<const NodeDef*> visited;
while (!queue.empty()) {
@ -60,15 +58,17 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
// The node has already been visited.
continue;
}
result.push_back(node);
name_to_fanin_node->insert(
std::pair<string, const NodeDef*>(node->name(), node));
fanin_nodes->push_back(node);
if (name_to_fanin_node) {
name_to_fanin_node->insert(
std::pair<string, const NodeDef*>(node->name(), node));
}
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
if (!in) {
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
*ill_formed = true;
return {};
return errors::InvalidArgument("Graph does not contain input ",
NodeName(input), " of node ",
node->name(), ".");
}
queue.push_back(in);
}
@ -82,7 +82,13 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
// So, we do not set ill_formed for missing _Send.
}
}
return result;
return Status::OK();
}
Status ComputeTransitiveFanin(const GraphDef& graph,
const std::vector<string>& terminal_nodes,
std::vector<const NodeDef*>* fanin_nodes) {
return ComputeTransitiveFanin(graph, terminal_nodes, nullptr, fanin_nodes);
}
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
@ -90,15 +96,9 @@ Status SetTransitiveFaninGraph(const GraphDef& input_graph,
const std::vector<string>& terminal_nodes) {
// Determines transitive fanin nodes from terminal nodes and add them to the
// output graph.
bool ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_fanin_node;
std::vector<const NodeDef*> keep = ComputeTransitiveFanin(
input_graph, terminal_nodes, &name_to_fanin_node, &ill_formed);
if (ill_formed) {
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
// let's be conservative and preserve the graph as is.
return errors::InvalidArgument("Invalid input graph.");
}
std::vector<const NodeDef*> keep;
TF_RETURN_IF_ERROR(
ComputeTransitiveFanin(input_graph, terminal_nodes, &keep));
// Try to keep the nodes ordered somewhat topologically since this helps
// further optimizations perform better.
output_graph->mutable_node()->Reserve(keep.size());

View File

@ -25,13 +25,17 @@ namespace tensorflow {
namespace grappler {
// Computes the transitive fanin of the graph based on reachability from the
// specified terminal nodes. ill_formed will be set to true if the graph is
// deemed structurally invalid. Returns the set of nodes comprising the
// transitive fanin.
std::vector<const NodeDef*> ComputeTransitiveFanin(
// specified terminal nodes. Returns the set of nodes comprising the
// transitive fanin into fanin_nodes. Optionally returns a map of name->node
// for that graph into name_to_fanin_node if that is not set to nullptr.
Status ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
bool* ill_formed);
std::vector<const NodeDef*>* fanin_nodes);
Status ComputeTransitiveFanin(const GraphDef& graph,
const std::vector<string>& terminal_nodes,
std::vector<const NodeDef*>* fanin_nodes);
// Creates output_graph from input_graph using the transitive fanin from the
// specified terminal nodes. Returns error if the input_graph is deemed

View File

@ -117,7 +117,7 @@ TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromMultipleTerminalNodes) {
ASSERT_FALSE(node_map.NodeExists("6"));
}
TEST_F(TransitiveFaninTest, InvalidGraph) {
TEST_F(TransitiveFaninTest, InvalidGraphOrTerminalNodes) {
GraphDef graph = CreateGraph({
{"1", {"2"}}, //
{"2", {"3"}}, //
@ -131,7 +131,11 @@ TEST_F(TransitiveFaninTest, InvalidGraph) {
const std::vector<string> terminal_nodes = {"1", "5"};
auto s = SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(), "Invalid input graph.");
EXPECT_EQ(s.error_message(), "Graph does not contain input 6 of node 5.");
const std::vector<string> invalid_terminal_nodes = {"0", "1", "5"};
s = SetTransitiveFaninGraph(graph, &output_graph, invalid_terminal_nodes);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(), "Graph does not contain terminal node 0.");
}
} // namespace