Automated g4 rollback of changelist 161087978
PiperOrigin-RevId: 161111023
This commit is contained in:
parent
cab048ecde
commit
27b341c800
@ -156,7 +156,7 @@ tf_library(
|
||||
)
|
||||
|
||||
# A test of tf_library that includes a graph with an unknown op, but where
|
||||
# the compilation works because the unknown op is not needed for the fetches.
|
||||
# the compilation works because the the unknown op is not needed for the fetches.
|
||||
tf_library(
|
||||
name = "test_graph_tfunknownop",
|
||||
testonly = 1,
|
||||
@ -166,29 +166,6 @@ tf_library(
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# A test of tf_library that includes a graph with an unknown op, but where
|
||||
# the compilation works because the op between the unknown op and the
|
||||
# fetches is a feed.
|
||||
tf_library(
|
||||
name = "test_graph_tfunknownop2",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfunknownop2.config.pbtxt",
|
||||
cpp_class = "UnknownOpAddComp",
|
||||
graph = "test_graph_tfunknownop.pbtxt",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# A test of tf_library that includes a graph with an unknown op, but where
|
||||
# the compilation works because the unknown op is fed.
|
||||
tf_library(
|
||||
name = "test_graph_tfunknownop3",
|
||||
testonly = 1,
|
||||
config = "test_graph_tfunknownop3.config.pbtxt",
|
||||
cpp_class = "UnknownOpAddComp",
|
||||
graph = "test_graph_tfunknownop.pbtxt",
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
# Utility library for benchmark binaries, used by the *_benchmark rules that are
|
||||
# added by the tfcompile bazel macro.
|
||||
cc_library(
|
||||
|
@ -78,51 +78,66 @@ Status DumpGraph(const MainFlags& flags, const string& name,
|
||||
return WriteTextProto(Env::Default(), file, graph_def);
|
||||
}
|
||||
|
||||
string TensorIdToString(const TensorId& id) {
|
||||
return strings::StrCat(id.node_name(), ":", id.output_index());
|
||||
}
|
||||
|
||||
typedef std::unordered_map<string, Node*> NodeMap;
|
||||
|
||||
// Each feed id identifies the positional output of some node, which may consist
|
||||
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
|
||||
// tensor with a placeholder. For each feed tensor, replaces all edges so they
|
||||
// point from a new _Arg node instead.
|
||||
// of multiple edges. For each feed node, replaces all matching edges so that
|
||||
// they point from a new _Arg node instead.
|
||||
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||
const protobuf::RepeatedPtrField<Feed>& feeds,
|
||||
const std::unordered_map<string, string>& feed_remapping) {
|
||||
const protobuf::RepeatedPtrField<Feed>& feeds) {
|
||||
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
||||
const Feed& feed = feeds[arg_index];
|
||||
// All feeds have been replaced by placeholders.
|
||||
const int output_index = 0;
|
||||
|
||||
const auto remap_it = feed_remapping.find(TensorIdToString(feed.id()));
|
||||
auto node_it = node_map.find(remap_it->second);
|
||||
const Node* feed_node = node_it->second;
|
||||
|
||||
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
|
||||
// "_shape" attr if we can determine it. That way the graph will be
|
||||
// initialized with whatever shapes we can infer, while the user can still
|
||||
// explicitly specify or override them.
|
||||
const TensorId& id = feed.id();
|
||||
auto it = node_map.find(id.node_name());
|
||||
if (it == node_map.end()) {
|
||||
return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
|
||||
}
|
||||
const Node* feed_node = it->second;
|
||||
if (id.output_index() >= feed_node->num_outputs()) {
|
||||
return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
|
||||
", output index should be < ",
|
||||
feed_node->num_outputs());
|
||||
}
|
||||
// TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
|
||||
// if we can determine it. That way the graph will be initialized with
|
||||
// whatever shapes we can infer, while the user can still explicitly specify
|
||||
// or override them.
|
||||
Node* arg_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
|
||||
.Attr("T", BaseType(feed_node->output_type(output_index)))
|
||||
.Attr("T", BaseType(feed_node->output_type(id.output_index())))
|
||||
.Attr("index", arg_index)
|
||||
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
|
||||
.Attr(kFeedIdAttr, TensorIdToString(id))
|
||||
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
||||
.Attr(kDebugNameAttr, feed.name())
|
||||
.Finalize(graph, &arg_node));
|
||||
|
||||
// Collects out-edges from the feed node that have a matching edge index;
|
||||
// these will be replaced with edges from the arg node instead.
|
||||
// these will be replaced with edges from the arg node instead. Also
|
||||
// replaces all control edges from Placeholder feed nodes; similar code
|
||||
// exists in subgraph::RewriteGraphForExecution.
|
||||
// TODO(toddw): Why only replace control edges from Placeholder?
|
||||
//
|
||||
// We must collect the edges first and process them in a second pass, since
|
||||
// removing the edge from the graph invalidates feed_node->out_edges.
|
||||
std::vector<const Edge*> feed_edges;
|
||||
for (const Edge* edge : feed_node->out_edges()) {
|
||||
if (edge->src_output() == output_index) {
|
||||
if (edge->src_output() == id.output_index() ||
|
||||
(edge->src_output() == Graph::kControlSlot &&
|
||||
feed_node->type_string() == "Placeholder")) {
|
||||
feed_edges.push_back(edge);
|
||||
}
|
||||
}
|
||||
for (const Edge* edge : feed_edges) {
|
||||
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
||||
if (edge->src_output() == id.output_index()) {
|
||||
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
||||
} else {
|
||||
CHECK_EQ(edge->src_output(), Graph::kControlSlot);
|
||||
graph->AddControlEdge(arg_node, edge->dst());
|
||||
}
|
||||
graph->RemoveEdge(edge);
|
||||
}
|
||||
}
|
||||
@ -164,16 +179,13 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
|
||||
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
|
||||
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
|
||||
// execution to know the input and output args for the generated function.
|
||||
Status RewriteAndPruneGraph(
|
||||
Graph* graph, const Config& config,
|
||||
const std::unordered_map<string, string>& feed_remapping,
|
||||
const MainFlags& flags) {
|
||||
Status RewriteAndPruneGraph(Graph* graph, const Config& config,
|
||||
const MainFlags& flags) {
|
||||
NodeMap node_map;
|
||||
for (Node* n : graph->nodes()) {
|
||||
node_map[n->name()] = n;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
|
||||
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
|
||||
std::unordered_set<const Node*> retval_nodes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||
@ -371,28 +383,17 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
|
||||
std::unique_ptr<Graph> g(new Graph(flib_def));
|
||||
|
||||
// Replace references to fed tensors with references to newly added
|
||||
// placeholders.
|
||||
GraphDef first_copy_def = graph_def;
|
||||
|
||||
// Maps from name:port of a feed to the name:port of the placeholder to use.
|
||||
std::unordered_map<string, string> feed_remapping;
|
||||
TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
|
||||
&feed_remapping, &first_copy_def));
|
||||
GraphDef copy_def;
|
||||
|
||||
// Prune the GraphDef first so that unknown ops that we aren't compiling get
|
||||
// filtered out.
|
||||
GraphDef second_copy_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
|
||||
TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, ©_def));
|
||||
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
|
||||
&second_copy_def, *g->op_registry(), 0 /*node_offset*/));
|
||||
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
|
||||
second_copy_def, g.get()));
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(),
|
||||
0 /*node_offset*/));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
|
||||
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
|
||||
*graph = std::move(g);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -6,12 +6,21 @@ node {
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape { dim { size: 1 } }
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
attr { key : "dtype" value { type: DT_INT32 } }
|
||||
attr {
|
||||
key : "dtype"
|
||||
value {
|
||||
type : DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "y_const"
|
||||
@ -21,37 +30,56 @@ node {
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape { dim { size: 1 } }
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 1
|
||||
}
|
||||
}
|
||||
int_val: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
attr { key: "dtype" value { type: DT_INT32 } }
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "x_y_sum"
|
||||
op : "Add"
|
||||
input : "x_const"
|
||||
input : "y_const"
|
||||
attr { key : "T" value { type: DT_INT32 } }
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "z"
|
||||
op : "SomeUnknownOp"
|
||||
input : "x_const"
|
||||
}
|
||||
node {
|
||||
name : "z_identity"
|
||||
op : "Identity"
|
||||
input : "z:1"
|
||||
attr { key : "T" value { type: DT_INT32 } }
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name : "x_z_sum"
|
||||
op : "Add"
|
||||
input : "x_const"
|
||||
input : "z_identity"
|
||||
attr { key : "T" value { type: DT_INT32 } }
|
||||
input : "z"
|
||||
attr {
|
||||
key : "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 15
|
||||
|
@ -1,25 +0,0 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "z_identity"}
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_z_sum" }
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
# Text form of tensorflow.tfcompile.Config proto.
|
||||
feed {
|
||||
id { node_name: "x_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "y_const" }
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
}
|
||||
feed {
|
||||
id { node_name: "z" output_index: 1}
|
||||
shape {
|
||||
dim { size: 1 }
|
||||
}
|
||||
type: DT_INT32
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_y_sum" }
|
||||
}
|
||||
fetch {
|
||||
id { node_name: "x_z_sum" }
|
||||
}
|
@ -7,7 +7,6 @@ option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.tfcompile";
|
||||
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
|
||||
// index of a particular node in the graph. If the output of the named node
|
||||
@ -24,12 +23,6 @@ message Feed {
|
||||
TensorId id = 1;
|
||||
TensorShapeProto shape = 2;
|
||||
string name = 3; // Optional name for generated code.
|
||||
|
||||
// Optional data type. This is not normally required, as the graph itself
|
||||
// contains this information. However, if the node being fed is an op that
|
||||
// is not linked into the tfcompile binary, then the type cannot be inferred
|
||||
// from the node; in this case, the type should be set here.
|
||||
DataType type = 4;
|
||||
};
|
||||
|
||||
// Fetch represents a single fetch tensor in the graph, which corresponds to an
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tfcompile {
|
||||
@ -120,105 +119,17 @@ Status ValidateConfig(const Config& config) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddPlaceholdersForFeeds(
|
||||
const Config& config, const OpRegistryInterface* op_registry,
|
||||
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
|
||||
struct PlaceholderInfo {
|
||||
const Feed* feed = nullptr; // point to Feed in <config>.
|
||||
string placeholder_name;
|
||||
DataType data_type = DT_INVALID;
|
||||
};
|
||||
|
||||
// Put each fed tensor into a map by name:port. A map is used for determinism
|
||||
// when creating placeholders (genrules want deterministic output).
|
||||
std::map<string, PlaceholderInfo> placeholder_info;
|
||||
for (int i = 0; i < config.feed_size(); ++i) {
|
||||
const Feed* feed = &config.feed(i);
|
||||
const string name_port = TensorIdToString(feed->id());
|
||||
auto& info = placeholder_info[name_port];
|
||||
info.feed = feed;
|
||||
info.placeholder_name = strings::StrCat(
|
||||
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
|
||||
(*feed_remapping)[name_port] = info.placeholder_name;
|
||||
}
|
||||
|
||||
// Verify node exists and determine data type.
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
for (int i = 0; i < graph_def->node_size(); ++i) {
|
||||
name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
|
||||
}
|
||||
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
|
||||
PlaceholderInfo& info = it->second;
|
||||
const TensorId& feed_id = info.feed->id();
|
||||
|
||||
// Find the existing node and determine data type.
|
||||
auto node_it = name_to_node.find(feed_id.node_name());
|
||||
if (node_it == name_to_node.end()) {
|
||||
return errors::NotFound("Can't find feed node: ",
|
||||
TensorIdToString(feed_id));
|
||||
}
|
||||
const NodeDef* existing = node_it->second;
|
||||
|
||||
if (info.feed->type() != DT_INVALID) {
|
||||
info.data_type = info.feed->type();
|
||||
} else {
|
||||
// Build the node in order to infer its type.
|
||||
Graph g(op_registry);
|
||||
Status status;
|
||||
Node* feed_node = g.AddNode(*existing, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
info.data_type =
|
||||
BaseType(feed_node->output_type(info.feed->id().output_index()));
|
||||
}
|
||||
}
|
||||
|
||||
// Create placeholders. Note that we could avoid creating a placeholder for
|
||||
// feeds which are already placeholders, but we omit that to avoid more cases
|
||||
// in this code.
|
||||
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
|
||||
const PlaceholderInfo& info = it->second;
|
||||
NodeDef* d = graph_def->add_node();
|
||||
d->set_name(info.placeholder_name);
|
||||
d->set_op("PlaceholderV2");
|
||||
auto& attr_map = *d->mutable_attr();
|
||||
attr_map["dtype"].set_type(info.data_type);
|
||||
*attr_map["shape"].mutable_shape() = info.feed->shape();
|
||||
}
|
||||
|
||||
// Rewrite references to the fed tensors to refer to the placeholder.
|
||||
for (int i = 0; i < graph_def->node_size(); ++i) {
|
||||
NodeDef* node_def = graph_def->mutable_node(i);
|
||||
for (int j = 0; j < node_def->input_size(); ++j) {
|
||||
auto id = ParseTensorName(node_def->input(j));
|
||||
auto it = placeholder_info.find(id.ToString());
|
||||
if (it != placeholder_info.end()) {
|
||||
node_def->set_input(j, it->second.placeholder_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
GraphDef* out) {
|
||||
*out = in;
|
||||
out->clear_node();
|
||||
|
||||
// Tensors needed for feeding.
|
||||
std::set<std::pair<string, int>> feed_tensors;
|
||||
for (const auto& feed_config : config.feed()) {
|
||||
feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
|
||||
feed_config.id().output_index()));
|
||||
}
|
||||
|
||||
// Maps node name to reachability.
|
||||
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
|
||||
for (const NodeDef& node : in.node()) {
|
||||
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
|
||||
}
|
||||
|
||||
// Traverse.
|
||||
std::queue<string> name_queue;
|
||||
for (int i = 0; i < config.fetch_size(); ++i) {
|
||||
name_queue.push(config.fetch(i).id().node_name());
|
||||
@ -238,19 +149,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
}
|
||||
map_entry.first = true;
|
||||
|
||||
// Push input nodes of the currently visited node to name_queue.
|
||||
for (const string& in_edge : map_entry.second->input()) {
|
||||
auto id = ParseTensorName(in_edge);
|
||||
const string node_name = id.first.ToString();
|
||||
if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
|
||||
feed_tensors.end()) {
|
||||
name_queue.push(node_name);
|
||||
} else {
|
||||
// The input tensor is from an edge that is being fed. Therefore,
|
||||
// we skip recursing down that edge, to avoid requiring nodes that
|
||||
// may not be needed (note that the input node may still be added
|
||||
// to name_queue later if one of its output edges is not being fed).
|
||||
}
|
||||
name_queue.push(ParseTensorName(in_edge).first.ToString());
|
||||
}
|
||||
}
|
||||
|
||||
@ -265,9 +165,5 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string TensorIdToString(const TensorId& id) {
|
||||
return strings::StrCat(id.node_name(), ":", id.output_index());
|
||||
}
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
@ -16,11 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
@ -34,23 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
|
||||
// ValidateConfig returns OK iff config is valid.
|
||||
Status ValidateConfig(const Config& config);
|
||||
|
||||
// Modifies <graph_def> to include placeholders for each fed tensor, and
|
||||
// update references to the fed tensors to refer to the placeholders.
|
||||
// The existing nodes referenced by the feeds are not removed or modified
|
||||
// (except where their input edges are modified by the replacement of other
|
||||
// feeds).
|
||||
Status AddPlaceholdersForFeeds(
|
||||
const Config& config, const OpRegistryInterface* op_registry,
|
||||
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
|
||||
|
||||
// Returns in <out> a copy of <in>, pruned to only include fetches from
|
||||
// <config>.
|
||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||
GraphDef* out);
|
||||
|
||||
// Returns node:port for the given <id>.
|
||||
string TensorIdToString(const TensorId& id);
|
||||
|
||||
} // namespace tfcompile
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user