Cleanup of transitive-fanin node computation to make details of invalid input available to its callers.
PiperOrigin-RevId: 307497151 Change-Id: I0e59172a24918b5de95832147d07e1e2fe6c0fe6
This commit is contained in:
parent
819330a213
commit
8f439272ed
tensorflow/core/grappler
@ -270,6 +270,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:utils",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -405,14 +406,10 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
}
|
||||
|
||||
// Get the nodes that would run to output fetch_nodes.
|
||||
bool ill_formed = false;
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
const std::vector<const NodeDef*> fetch_fanin_nodes =
|
||||
ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node, &ill_formed);
|
||||
if (ill_formed) {
|
||||
return errors::InvalidArgument(
|
||||
"Ill formed graph or invalid set of fetch nodes specified");
|
||||
}
|
||||
std::vector<const NodeDef*> fetch_fanin_nodes;
|
||||
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
|
||||
&fetch_fanin_nodes));
|
||||
|
||||
// Once ComputeTransitiveFanin is complete, only the nodes that can be reached
|
||||
// from the fetch nodes are scheduled. So the scheduled nodes should be
|
||||
|
@ -44,7 +44,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
@ -54,9 +54,9 @@ void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph,
|
||||
if (fetch_nodes.empty()) {
|
||||
*graph = metagraph.graph_def();
|
||||
} else {
|
||||
std::vector<const tensorflow::NodeDef*> fanin_nodes =
|
||||
tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(),
|
||||
fetch_nodes);
|
||||
std::vector<const tensorflow::NodeDef*> fanin_nodes;
|
||||
TF_CHECK_OK(tensorflow::grappler::ComputeTransitiveFanin(
|
||||
metagraph.graph_def(), fetch_nodes, &fanin_nodes));
|
||||
for (const tensorflow::NodeDef* node : fanin_nodes) {
|
||||
*(graph->add_node()) = *node;
|
||||
}
|
||||
|
@ -50,7 +50,9 @@ GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
|
||||
return ComputeTransitiveFanin(graph, fetch);
|
||||
std::vector<const NodeDef*> fanin_nodes;
|
||||
TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
|
||||
return fanin_nodes;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
|
||||
@ -60,15 +62,20 @@ std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
|
||||
enqueue_ops.push_back(enqueue_op);
|
||||
}
|
||||
}
|
||||
return ComputeTransitiveFanin(graph, enqueue_ops);
|
||||
std::vector<const NodeDef*> fanin_nodes;
|
||||
TF_CHECK_OK(ComputeTransitiveFanin(graph, fetch, &fanin_nodes));
|
||||
return fanin_nodes;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
|
||||
return ComputeTransitiveFanin(graph, init_ops);
|
||||
std::vector<const NodeDef*> fanin_nodes;
|
||||
TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin_nodes));
|
||||
return fanin_nodes;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
|
||||
std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
|
||||
std::vector<const NodeDef*> fanin;
|
||||
TF_CHECK_OK(ComputeTransitiveFanin(graph, init_ops, &fanin));
|
||||
std::vector<const NodeDef*> vars;
|
||||
for (const NodeDef* node : fanin) {
|
||||
if (IsVariable(*node)) {
|
||||
@ -200,22 +207,5 @@ GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
|
||||
return optimization_options_;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes) {
|
||||
bool ill_formed = false;
|
||||
std::vector<const NodeDef*> result =
|
||||
ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
|
||||
CHECK(!ill_formed);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
bool* ill_formed) {
|
||||
std::unordered_map<string, const NodeDef*> name_to_fanin_node;
|
||||
return ComputeTransitiveFanin(graph, terminal_nodes, &name_to_fanin_node,
|
||||
ill_formed);
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -133,25 +133,6 @@ struct GrapplerItem {
|
||||
OptimizationOptions optimization_options_;
|
||||
};
|
||||
|
||||
// Return the transitive fanin of a set of terminal nodes.
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes);
|
||||
|
||||
// Return the transitive fanin of a set of terminal nodes. Sets 'ill_formed' to
|
||||
// true if one of the node is missing in the graph, or some node inputs don't
|
||||
// exist.
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
bool* ill_formed);
|
||||
|
||||
// Return the transitive fanin of a set of terminal nodes. Sets 'ill_formed' to
|
||||
// true if one of the node is missing in the graph, or some node inputs don't
|
||||
// exist. Sets name_to_fanin_node for name to fanin nodes map.
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
|
||||
bool* ill_formed);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -75,6 +75,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -147,7 +148,8 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
||||
}
|
||||
LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
|
||||
|
||||
auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
|
||||
std::vector<const NodeDef*> train_nodes;
|
||||
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
|
||||
LOG(INFO) << "Number of training nodes: " << train_nodes.size();
|
||||
|
||||
const NodeDef* dequeue_node;
|
||||
@ -161,7 +163,8 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
||||
std::vector<const NodeDef*> input_nodes;
|
||||
if (dequeue_node) {
|
||||
LOG(INFO) << "Dequeue node: " << dequeue_node->name();
|
||||
input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
|
||||
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, {dequeue_node->name()},
|
||||
{}, &input_nodes));
|
||||
}
|
||||
LOG(INFO) << "Number of input nodes: " << input_nodes.size();
|
||||
|
||||
|
@ -20,15 +20,15 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
Status ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
|
||||
bool* ill_formed) {
|
||||
*ill_formed = false;
|
||||
std::vector<const NodeDef*>* fanin_nodes) {
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
std::unordered_map<string, const NodeDef*> name_to_send;
|
||||
for (const auto& node : graph.node()) {
|
||||
@ -43,14 +43,12 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
for (const string& root : terminal_nodes) {
|
||||
const NodeDef* node = name_to_node[NodeName(root)];
|
||||
if (!node) {
|
||||
*ill_formed = true;
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
|
||||
return {};
|
||||
return errors::InvalidArgument("Graph does not contain terminal node ",
|
||||
root, ".");
|
||||
}
|
||||
queue.push_back(node);
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> result;
|
||||
std::unordered_set<const NodeDef*> visited;
|
||||
|
||||
while (!queue.empty()) {
|
||||
@ -60,15 +58,17 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
// The node has already been visited.
|
||||
continue;
|
||||
}
|
||||
result.push_back(node);
|
||||
name_to_fanin_node->insert(
|
||||
std::pair<string, const NodeDef*>(node->name(), node));
|
||||
fanin_nodes->push_back(node);
|
||||
if (name_to_fanin_node) {
|
||||
name_to_fanin_node->insert(
|
||||
std::pair<string, const NodeDef*>(node->name(), node));
|
||||
}
|
||||
for (const string& input : node->input()) {
|
||||
const NodeDef* in = name_to_node[NodeName(input)];
|
||||
if (!in) {
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
|
||||
*ill_formed = true;
|
||||
return {};
|
||||
return errors::InvalidArgument("Graph does not contain input ",
|
||||
NodeName(input), " of node ",
|
||||
node->name(), ".");
|
||||
}
|
||||
queue.push_back(in);
|
||||
}
|
||||
@ -82,7 +82,13 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
// So, we do not set ill_formed for missing _Send.
|
||||
}
|
||||
}
|
||||
return result;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputeTransitiveFanin(const GraphDef& graph,
|
||||
const std::vector<string>& terminal_nodes,
|
||||
std::vector<const NodeDef*>* fanin_nodes) {
|
||||
return ComputeTransitiveFanin(graph, terminal_nodes, nullptr, fanin_nodes);
|
||||
}
|
||||
|
||||
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
|
||||
@ -90,15 +96,9 @@ Status SetTransitiveFaninGraph(const GraphDef& input_graph,
|
||||
const std::vector<string>& terminal_nodes) {
|
||||
// Determines transitive fanin nodes from terminal nodes and add them to the
|
||||
// output graph.
|
||||
bool ill_formed = false;
|
||||
std::unordered_map<string, const NodeDef*> name_to_fanin_node;
|
||||
std::vector<const NodeDef*> keep = ComputeTransitiveFanin(
|
||||
input_graph, terminal_nodes, &name_to_fanin_node, &ill_formed);
|
||||
if (ill_formed) {
|
||||
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
|
||||
// let's be conservative and preserve the graph as is.
|
||||
return errors::InvalidArgument("Invalid input graph.");
|
||||
}
|
||||
std::vector<const NodeDef*> keep;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ComputeTransitiveFanin(input_graph, terminal_nodes, &keep));
|
||||
// Try to keep the nodes ordered somewhat topologically since this helps
|
||||
// further optimizations perform better.
|
||||
output_graph->mutable_node()->Reserve(keep.size());
|
||||
|
@ -25,13 +25,17 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Computes the transitive fanin of the graph based on reachability from the
|
||||
// specified terminal nodes. ill_formed will be set to true if the graph is
|
||||
// deemed structurally invalid. Returns the set of nodes comprising the
|
||||
// transitive fanin.
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
// specified terminal nodes. Returns the set of nodes comprising the
|
||||
// transitive fanin into fanin_nodes. Optionally returns a map of name->node
|
||||
// for that graph into name_to_fanin_node if that is not set to nullptr.
|
||||
Status ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
std::unordered_map<string, const NodeDef*>* name_to_fanin_node,
|
||||
bool* ill_formed);
|
||||
std::vector<const NodeDef*>* fanin_nodes);
|
||||
|
||||
Status ComputeTransitiveFanin(const GraphDef& graph,
|
||||
const std::vector<string>& terminal_nodes,
|
||||
std::vector<const NodeDef*>* fanin_nodes);
|
||||
|
||||
// Creates output_graph from input_graph using the transitive fanin from the
|
||||
// specified terminal nodes. Returns error if the input_graph is deemed
|
||||
|
@ -117,7 +117,7 @@ TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromMultipleTerminalNodes) {
|
||||
ASSERT_FALSE(node_map.NodeExists("6"));
|
||||
}
|
||||
|
||||
TEST_F(TransitiveFaninTest, InvalidGraph) {
|
||||
TEST_F(TransitiveFaninTest, InvalidGraphOrTerminalNodes) {
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {"2"}}, //
|
||||
{"2", {"3"}}, //
|
||||
@ -131,7 +131,11 @@ TEST_F(TransitiveFaninTest, InvalidGraph) {
|
||||
const std::vector<string> terminal_nodes = {"1", "5"};
|
||||
auto s = SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_EQ(s.error_message(), "Invalid input graph.");
|
||||
EXPECT_EQ(s.error_message(), "Graph does not contain input 6 of node 5.");
|
||||
const std::vector<string> invalid_terminal_nodes = {"0", "1", "5"};
|
||||
s = SetTransitiveFaninGraph(graph, &output_graph, invalid_terminal_nodes);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_EQ(s.error_message(), "Graph does not contain terminal node 0.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user