Fix TF2XLA's InitGraph for unused feeds.
If a feed is not used, previously it would prune the placeholders and cause crashes. PiperOrigin-RevId: 311754319 Change-Id: Ie1ad67c21ffb83ba88aeabea94c416473df099a0
This commit is contained in:
parent
53c634a6c1
commit
2540d202b5
@ -49,10 +49,12 @@ typedef std::unordered_map<string, Node*> NodeMap;
|
||||
// Each feed id identifies the positional output of some node, which may consist
|
||||
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
|
||||
// tensor with a placeholder. For each feed tensor, replaces all edges so they
|
||||
// point from a new _Arg node instead.
|
||||
// point from a new _Arg node instead. The newly created _Arg nodes are added to
|
||||
// `arg_nodes`.
|
||||
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
|
||||
const std::unordered_map<string, string>& feed_remapping) {
|
||||
const std::unordered_map<string, string>& feed_remapping,
|
||||
std::unordered_set<const Node*>* arg_nodes) {
|
||||
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
||||
const tf2xla::Feed& feed = feeds[arg_index];
|
||||
// All feeds have been replaced by placeholders.
|
||||
@ -86,6 +88,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
||||
.Attr(kDebugNameAttr, feed.name())
|
||||
.Finalize(graph, &arg_node));
|
||||
arg_nodes->insert(arg_node);
|
||||
|
||||
// Collects out-edges from the feed node that have a matching edge index;
|
||||
// these will be replaced with edges from the arg node instead.
|
||||
@ -149,13 +152,13 @@ Status RewriteAndPruneGraph(
|
||||
for (Node* n : graph->nodes()) {
|
||||
node_map[n->name()] = n;
|
||||
}
|
||||
std::unordered_set<const Node*> nodes_to_keep;
|
||||
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping,
|
||||
&nodes_to_keep));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
|
||||
std::unordered_set<const Node*> retval_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep));
|
||||
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
|
||||
PruneForReverseReachability(graph, std::move(retval_nodes));
|
||||
PruneForReverseReachability(graph, std::move(nodes_to_keep));
|
||||
FixupSourceAndSinkEdges(graph);
|
||||
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
|
||||
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
||||
@ -277,8 +280,16 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
||||
// Prune the GraphDef first so that unknown ops that we aren't compiling get
|
||||
// filtered out.
|
||||
GraphDef second_copy_def;
|
||||
// Add the placeholder nodes as "fetches" in prune_config, such that they will
|
||||
// be preserved in PruneGraphDefInto.
|
||||
auto prune_config = config;
|
||||
for (const auto& entry : feed_remapping) {
|
||||
auto ph = prune_config.add_fetch();
|
||||
*ph->mutable_id()->mutable_node_name() = entry.second;
|
||||
ph->mutable_id()->set_output_index(0);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
|
||||
PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def));
|
||||
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
|
||||
&second_copy_def, *g->op_registry(), /*node_offset=*/0));
|
||||
|
@ -99,5 +99,42 @@ TEST(ConvertGraphDefToXla, Sum) {
|
||||
ConvertGraphDefToXla(graph_def, config, client, &computation)));
|
||||
}
|
||||
|
||||
TEST(ConvertGraphDefToXla, SumWithUnusedArgument) {
|
||||
GraphDef graph_def = SumGraph();
|
||||
tf2xla::Config config = SumConfig();
|
||||
NodeDef* unused = graph_def.add_node();
|
||||
unused->set_name("unused");
|
||||
unused->set_op("Placeholder");
|
||||
(*unused->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
|
||||
config.add_feed()->mutable_id()->set_node_name("unused");
|
||||
|
||||
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
|
||||
xla::XlaComputation computation;
|
||||
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
|
||||
|
||||
// Set up arguments.
|
||||
auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
|
||||
auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
|
||||
auto x_global_or = client->TransferToServer(x_literal);
|
||||
auto y_global_or = client->TransferToServer(y_literal);
|
||||
auto unused_global_or = client->TransferToServer(y_literal);
|
||||
TF_EXPECT_OK(x_global_or.status());
|
||||
TF_EXPECT_OK(y_global_or.status());
|
||||
TF_EXPECT_OK(unused_global_or.status());
|
||||
std::unique_ptr<xla::GlobalData> x_global =
|
||||
std::move(x_global_or.ValueOrDie());
|
||||
std::unique_ptr<xla::GlobalData> y_global =
|
||||
std::move(y_global_or.ValueOrDie());
|
||||
std::unique_ptr<xla::GlobalData> unused_global =
|
||||
std::move(unused_global_or.ValueOrDie());
|
||||
|
||||
// Execute and check result.
|
||||
auto result_or = client->ExecuteAndTransfer(
|
||||
computation, {x_global.get(), y_global.get(), unused_global.get()});
|
||||
TF_EXPECT_OK(result_or.status());
|
||||
xla::Literal result = std::move(result_or.ValueOrDie());
|
||||
EXPECT_EQ("(\ns32[] 42\n)", result.ToString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user