Fix TF2XLA's InitGraph for unused feeds.

If a feed is not used, previously it would prune the placeholders and cause crashes.

PiperOrigin-RevId: 311754319
Change-Id: Ie1ad67c21ffb83ba88aeabea94c416473df099a0
This commit is contained in:
Yuanzhong Xu 2020-05-15 10:19:17 -07:00 committed by TensorFlower Gardener
parent 53c634a6c1
commit 2540d202b5
2 changed files with 56 additions and 8 deletions

View File

@ -49,10 +49,12 @@ typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
// point from a new _Arg node instead. The newly created _Arg nodes are added to
// `arg_nodes`.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
const std::unordered_map<string, string>& feed_remapping) {
const std::unordered_map<string, string>& feed_remapping,
std::unordered_set<const Node*>* arg_nodes) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const tf2xla::Feed& feed = feeds[arg_index];
// All feeds have been replaced by placeholders.
@ -86,6 +88,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
arg_nodes->insert(arg_node);
// Collects out-edges from the feed node that have a matching edge index;
// these will be replaced with edges from the arg node instead.
@ -149,13 +152,13 @@ Status RewriteAndPruneGraph(
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
std::unordered_set<const Node*> nodes_to_keep;
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping,
&nodes_to_keep));
TF_RETURN_IF_ERROR(
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep));
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, std::move(retval_nodes));
PruneForReverseReachability(graph, std::move(nodes_to_keep));
FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
@ -277,8 +280,16 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
GraphDef second_copy_def;
// Add the placeholder nodes as "fetches" in prune_config, such that they will
// be preserved in PruneGraphDefInto.
auto prune_config = config;
for (const auto& entry : feed_remapping) {
auto ph = prune_config.add_fetch();
*ph->mutable_id()->mutable_node_name() = entry.second;
ph->mutable_id()->set_output_index(0);
}
TF_RETURN_IF_ERROR(
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
&second_copy_def, *g->op_registry(), /*node_offset=*/0));

View File

@ -99,5 +99,42 @@ TEST(ConvertGraphDefToXla, Sum) {
ConvertGraphDefToXla(graph_def, config, client, &computation)));
}
TEST(ConvertGraphDefToXla, SumWithUnusedArgument) {
GraphDef graph_def = SumGraph();
tf2xla::Config config = SumConfig();
NodeDef* unused = graph_def.add_node();
unused->set_name("unused");
unused->set_op("Placeholder");
(*unused->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
config.add_feed()->mutable_id()->set_node_name("unused");
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
xla::XlaComputation computation;
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
// Set up arguments.
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
auto x_global_or = client->TransferToServer(x_literal);
auto y_global_or = client->TransferToServer(y_literal);
auto unused_global_or = client->TransferToServer(y_literal);
TF_EXPECT_OK(x_global_or.status());
TF_EXPECT_OK(y_global_or.status());
TF_EXPECT_OK(unused_global_or.status());
std::unique_ptr<xla::GlobalData> x_global =
std::move(x_global_or.ValueOrDie());
std::unique_ptr<xla::GlobalData> y_global =
std::move(y_global_or.ValueOrDie());
std::unique_ptr<xla::GlobalData> unused_global =
std::move(unused_global_or.ValueOrDie());
// Execute and check result.
auto result_or = client->ExecuteAndTransfer(
computation, {x_global.get(), y_global.get(), unused_global.get()});
TF_EXPECT_OK(result_or.status());
xla::Literal result = std::move(result_or.ValueOrDie());
EXPECT_EQ("(\ns32[] 42\n)", result.ToString());
}
} // namespace
} // namespace tensorflow