Automated g4 rollback of changelist 161087978

PiperOrigin-RevId: 161111023
A. Unique TensorFlower 2017-07-06 11:59:53 -07:00 committed by TensorFlower Gardener
parent cab048ecde
commit 27b341c800
8 changed files with 89 additions and 260 deletions

View File

@@ -156,7 +156,7 @@ tf_library(
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is not needed for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
@@ -166,29 +166,6 @@ tf_library(
tags = ["manual"],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the op between the unknown op and the
# fetches is a feed.
tf_library(
name = "test_graph_tfunknownop2",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
tags = ["manual"],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is fed.
tf_library(
name = "test_graph_tfunknownop3",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
tags = ["manual"],
)
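# Each tf_library rule above AOT-compiles its graph into a C++ class named by
# cpp_class. As a rough usage sketch (the generated header path and the
# arg/result accessor names below are assumptions for illustration, not taken
# from this change):
#
#   #include <cstdint>
#   // Hypothetical path to the tfcompile-generated header for the rule above.
#   #include "tensorflow/compiler/aot/tests/test_graph_tfunknownop.h"
#
#   int main() {
#     UnknownOpAddComp add;  // class named by cpp_class
#     // Write the two fed tensors declared in the .config.pbtxt (shape [1]).
#     *static_cast<int32_t*>(add.arg0_data()) = 1;
#     *static_cast<int32_t*>(add.arg1_data()) = 2;
#     add.Run();             // runs the AOT-compiled graph
#     // Read back the first fetch.
#     return *static_cast<const int32_t*>(add.result0_data());
#   }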
# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(

View File

@@ -78,51 +78,66 @@ Status DumpGraph(const MainFlags& flags, const string& name,
return WriteTextProto(Env::Default(), file, graph_def);
}
string TensorIdToString(const TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
typedef std::unordered_map<string, Node*> NodeMap;
// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
// of multiple edges. For each feed node, replaces all matching edges so that
// they point from a new _Arg node instead.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
const protobuf::RepeatedPtrField<Feed>& feeds,
const std::unordered_map<string, string>& feed_remapping) {
const protobuf::RepeatedPtrField<Feed>& feeds) {
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
const Feed& feed = feeds[arg_index];
// All feeds have been replaced by placeholders.
const int output_index = 0;
const auto remap_it = feed_remapping.find(TensorIdToString(feed.id()));
auto node_it = node_map.find(remap_it->second);
const Node* feed_node = node_it->second;
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
// "_shape" attr if we can determine it. That way the graph will be
// initialized with whatever shapes we can infer, while the user can still
// explicitly specify or override them.
const TensorId& id = feed.id();
auto it = node_map.find(id.node_name());
if (it == node_map.end()) {
return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
}
const Node* feed_node = it->second;
if (id.output_index() >= feed_node->num_outputs()) {
return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
", output index should be < ",
feed_node->num_outputs());
}
// TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
// if we can determine it. That way the graph will be initialized with
// whatever shapes we can infer, while the user can still explicitly specify
// or override them.
Node* arg_node = nullptr;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
.Attr("T", BaseType(feed_node->output_type(output_index)))
.Attr("T", BaseType(feed_node->output_type(id.output_index())))
.Attr("index", arg_index)
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
.Attr(kFeedIdAttr, TensorIdToString(id))
.Attr(kShapeAttr, TensorShape(feed.shape()))
.Attr(kDebugNameAttr, feed.name())
.Finalize(graph, &arg_node));
// Collects out-edges from the feed node that have a matching edge index;
// these will be replaced with edges from the arg node instead.
// these will be replaced with edges from the arg node instead. Also
// replaces all control edges from Placeholder feed nodes; similar code
// exists in subgraph::RewriteGraphForExecution.
// TODO(toddw): Why only replace control edges from Placeholder?
//
// We must collect the edges first and process them in a second pass, since
// removing the edge from the graph invalidates feed_node->out_edges.
std::vector<const Edge*> feed_edges;
for (const Edge* edge : feed_node->out_edges()) {
if (edge->src_output() == output_index) {
if (edge->src_output() == id.output_index() ||
(edge->src_output() == Graph::kControlSlot &&
feed_node->type_string() == "Placeholder")) {
feed_edges.push_back(edge);
}
}
for (const Edge* edge : feed_edges) {
if (edge->src_output() == id.output_index()) {
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
} else {
CHECK_EQ(edge->src_output(), Graph::kControlSlot);
graph->AddControlEdge(arg_node, edge->dst());
}
graph->RemoveEdge(edge);
}
}
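// Aside: the collect-then-replace structure above exists because
// graph->RemoveEdge() mutates feed_node->out_edges() while it is being
// iterated. A minimal standalone sketch of the same pattern, with a plain
// container standing in for the edge list (names invented for illustration):
//
//   #include <set>
//   #include <vector>
//
//   // Erasing from a container during range-iteration invalidates the
//   // iterator, so snapshot the victims first, then mutate.
//   void EraseEvens(std::set<int>* edges) {
//     std::vector<int> doomed;
//     for (int e : *edges) {
//       if (e % 2 == 0) doomed.push_back(e);  // pass 1: read only
//     }
//     for (int e : doomed) {
//       edges->erase(e);                      // pass 2: safe to mutate
//     }
//   }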
@@ -164,16 +179,13 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
Graph* graph, const Config& config,
const std::unordered_map<string, string>& feed_remapping,
Status RewriteAndPruneGraph(Graph* graph, const Config& config,
const MainFlags& flags) {
NodeMap node_map;
for (Node* n : graph->nodes()) {
node_map[n->name()] = n;
}
TF_RETURN_IF_ERROR(
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
std::unordered_set<const Node*> retval_nodes;
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
@@ -371,28 +383,17 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
std::unique_ptr<Graph> g(new Graph(flib_def));
// Replace references to fed tensors with references to newly added
// placeholders.
GraphDef first_copy_def = graph_def;
// Maps from name:port of a feed to the name:port of the placeholder to use.
std::unordered_map<string, string> feed_remapping;
TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
&feed_remapping, &first_copy_def));
GraphDef copy_def;
// Prune the GraphDef first so that unknown ops that we aren't compiling get
// filtered out.
GraphDef second_copy_def;
TF_RETURN_IF_ERROR(
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &copy_def));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
&second_copy_def, *g->op_registry(), 0 /*node_offset*/));
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
0 /*node_offset*/));
TF_RETURN_IF_ERROR(
RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
*graph = std::move(g);
return Status::OK();
}
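// A hedged sketch of calling the restored InitGraph from inside the
// tensorflow::tfcompile namespace; the tail of the signature (flags and the
// output parameter) is inferred from the body above, since the hunk header
// truncates it:
//
//   #include <memory>
//   #include "tensorflow/core/framework/graph.pb.h"
//   #include "tensorflow/core/graph/graph.h"
//   #include "tensorflow/core/lib/core/errors.h"
//
//   // Assumed signature: Status InitGraph(const GraphDef&, const Config&,
//   //                                     const MainFlags&,
//   //                                     std::unique_ptr<Graph>*);
//   Status BuildCompileGraph(const GraphDef& graph_def, const Config& config,
//                            const MainFlags& flags) {
//     std::unique_ptr<Graph> graph;
//     // Prunes to the fetches, fills in default attrs, converts to a Graph,
//     // and rewrites feeds/fetches into _Arg/_Retval nodes, per the code
//     // above.
//     TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));
//     // <graph> now has _Arg inputs and _Retval outputs, ready to compile.
//     return Status::OK();
//   }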

View File

@@ -6,12 +6,21 @@ node {
value {
tensor {
dtype: DT_INT32
tensor_shape { dim { size: 1 } }
tensor_shape {
dim {
size: 1
}
}
int_val: 1
}
}
}
attr { key : "dtype" value { type: DT_INT32 } }
attr {
key : "dtype"
value {
type : DT_INT32
}
}
}
node {
name : "y_const"
@@ -21,37 +30,56 @@ node {
value {
tensor {
dtype: DT_INT32
tensor_shape { dim { size: 1 } }
tensor_shape {
dim {
size: 1
}
}
int_val: 2
}
}
}
attr { key: "dtype" value { type: DT_INT32 } }
attr {
key: "dtype"
value {
type: DT_INT32
}
}
}
node {
name : "x_y_sum"
op : "Add"
input : "x_const"
input : "y_const"
attr { key : "T" value { type: DT_INT32 } }
attr {
key : "T"
value {
type: DT_INT32
}
}
}
node {
name : "z"
op : "SomeUnknownOp"
input : "x_const"
attr {
key : "T"
value {
type: DT_INT32
}
}
}
node {
name : "z_identity"
op : "Identity"
input : "z:1"
attr { key : "T" value { type: DT_INT32 } }
}
node {
name : "x_z_sum"
op : "Add"
input : "x_const"
input : "z_identity"
attr { key : "T" value { type: DT_INT32 } }
input : "z"
attr {
key : "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 15

View File

@@ -1,25 +0,0 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "z_identity"}
shape {
dim { size: 1 }
}
}
fetch {
id { node_name: "x_y_sum" }
}
fetch {
id { node_name: "x_z_sum" }
}

View File

@@ -1,26 +0,0 @@
# Text form of tensorflow.tfcompile.Config proto.
feed {
id { node_name: "x_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "y_const" }
shape {
dim { size: 1 }
}
}
feed {
id { node_name: "z" output_index: 1}
shape {
dim { size: 1 }
}
type: DT_INT32
}
fetch {
id { node_name: "x_y_sum" }
}
fetch {
id { node_name: "x_z_sum" }
}
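# Both of the deleted files above are text-format tensorflow.tfcompile.Config
# protos. A minimal C++ sketch of parsing one (the function name and file
# path are illustrative, not part of this change):
#
#   #include <fstream>
#   #include <sstream>
#   #include <string>
#   #include "google/protobuf/text_format.h"
#   #include "tensorflow/compiler/aot/tfcompile.pb.h"
#
#   // Returns true iff <path> parses as a text-form Config proto.
#   bool LoadTfcompileConfig(const std::string& path,
#                            tensorflow::tfcompile::Config* config) {
#     std::ifstream in(path);
#     std::stringstream buffer;
#     buffer << in.rdbuf();
#     return google::protobuf::TextFormat::ParseFromString(buffer.str(),
#                                                          config);
#   }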

View File

@@ -7,7 +7,6 @@ option java_multiple_files = true;
option java_package = "org.tensorflow.tfcompile";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
// index of a particular node in the graph. If the output of the named node
@@ -24,12 +23,6 @@ message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
// Optional data type. This is not normally required, as the graph itself
// contains this information. However, if the node being fed is an op that
// is not linked into the tfcompile binary, then the type cannot be inferred
// from the node; in this case, the type should be set here.
DataType type = 4;
};
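// For context on the field deleted above: the rolled-back change set it when
// the fed op (SomeUnknownOp in the test graph) was not linked into the
// tfcompile binary, so the dtype could not be inferred from the node. A
// hedged sketch of building such a feed via the generated proto API (the
// helper name is invented for illustration):
//
//   #include "tensorflow/compiler/aot/tfcompile.pb.h"
//
//   // Mirrors the z:1 feed from the deleted test_graph_tfunknownop3 config.
//   tensorflow::tfcompile::Feed MakeTypedFeed() {
//     tensorflow::tfcompile::Feed feed;
//     feed.mutable_id()->set_node_name("z");
//     feed.mutable_id()->set_output_index(1);
//     feed.mutable_shape()->add_dim()->set_size(1);
//     feed.set_type(tensorflow::DT_INT32);  // field removed by this rollback
//     return feed;
//   }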
// Fetch represents a single fetch tensor in the graph, which corresponds to an

View File

@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tfcompile {
@@ -120,105 +119,17 @@ Status ValidateConfig(const Config& config) {
return Status::OK();
}
Status AddPlaceholdersForFeeds(
const Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
struct PlaceholderInfo {
const Feed* feed = nullptr; // point to Feed in <config>.
string placeholder_name;
DataType data_type = DT_INVALID;
};
// Put each fed tensor into a map by name:port. A map is used for determinism
// when creating placeholders (genrules want deterministic output).
std::map<string, PlaceholderInfo> placeholder_info;
for (int i = 0; i < config.feed_size(); ++i) {
const Feed* feed = &config.feed(i);
const string name_port = TensorIdToString(feed->id());
auto& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = strings::StrCat(
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
(*feed_remapping)[name_port] = info.placeholder_name;
}
// Verify node exists and determine data type.
std::unordered_map<string, const NodeDef*> name_to_node;
for (int i = 0; i < graph_def->node_size(); ++i) {
name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
}
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
PlaceholderInfo& info = it->second;
const TensorId& feed_id = info.feed->id();
// Find the existing node and determine data type.
auto node_it = name_to_node.find(feed_id.node_name());
if (node_it == name_to_node.end()) {
return errors::NotFound("Can't find feed node: ",
TensorIdToString(feed_id));
}
const NodeDef* existing = node_it->second;
if (info.feed->type() != DT_INVALID) {
info.data_type = info.feed->type();
} else {
// Build the node in order to infer its type.
Graph g(op_registry);
Status status;
Node* feed_node = g.AddNode(*existing, &status);
TF_RETURN_IF_ERROR(status);
info.data_type =
BaseType(feed_node->output_type(info.feed->id().output_index()));
}
}
// Create placeholders. Note that we could avoid creating a placeholder for
// feeds which are already placeholders, but we omit that to avoid more cases
// in this code.
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
const PlaceholderInfo& info = it->second;
NodeDef* d = graph_def->add_node();
d->set_name(info.placeholder_name);
d->set_op("PlaceholderV2");
auto& attr_map = *d->mutable_attr();
attr_map["dtype"].set_type(info.data_type);
*attr_map["shape"].mutable_shape() = info.feed->shape();
}
// Rewrite references to the fed tensors to refer to the placeholder.
for (int i = 0; i < graph_def->node_size(); ++i) {
NodeDef* node_def = graph_def->mutable_node(i);
for (int j = 0; j < node_def->input_size(); ++j) {
auto id = ParseTensorName(node_def->input(j));
auto it = placeholder_info.find(id.ToString());
if (it != placeholder_info.end()) {
node_def->set_input(j, it->second.placeholder_name);
}
}
}
return Status::OK();
}
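// Tying the deleted pieces together: for the z:1 feed from the deleted
// config, the code above produced a PlaceholderV2 named "aot_feed_1/z" and
// recorded "z:1" -> "aot_feed_1/z" in feed_remapping. A sketch of the call
// site, mirroring the removed InitGraph lines earlier in this diff (the
// wrapper name is invented for illustration):
//
//   #include <unordered_map>
//   #include "tensorflow/core/framework/graph.pb.h"
//   #include "tensorflow/core/framework/op.h"
//   #include "tensorflow/core/lib/core/errors.h"
//
//   Status ReplaceFeedsWithPlaceholders(const Config& config,
//                                       GraphDef* graph_def) {
//     // Maps feed "name:port" to the name of its new placeholder node.
//     std::unordered_map<string, string> feed_remapping;
//     TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(
//         config, OpRegistry::Global(), &feed_remapping, graph_def));
//     // e.g. feed_remapping["z:1"] == "aot_feed_1/z"; every former consumer
//     // of z:1 now reads from that PlaceholderV2 instead.
//     return Status::OK();
//   }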
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
// Tensors needed for feeding.
std::set<std::pair<string, int>> feed_tensors;
for (const auto& feed_config : config.feed()) {
feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
feed_config.id().output_index()));
}
// Maps node name to reachability.
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
for (const NodeDef& node : in.node()) {
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
}
// Traverse.
std::queue<string> name_queue;
for (int i = 0; i < config.fetch_size(); ++i) {
name_queue.push(config.fetch(i).id().node_name());
@@ -238,19 +149,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
}
map_entry.first = true;
// Push input nodes of the currently visited node to name_queue.
for (const string& in_edge : map_entry.second->input()) {
auto id = ParseTensorName(in_edge);
const string node_name = id.first.ToString();
if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
feed_tensors.end()) {
name_queue.push(node_name);
} else {
// The input tensor is from an edge that is being fed. Therefore,
// we skip recursing down that edge, to avoid requiring nodes that
// may not be needed (note that the input node may still be added
// to name_queue later if one of its output edges is not being fed).
}
name_queue.push(ParseTensorName(in_edge).first.ToString());
}
}
@@ -265,9 +165,5 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
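// A small usage sketch of the pruning step, under the signature shown above
// (the wrapper name is invented for illustration):
//
//   #include "tensorflow/compiler/aot/tfcompile.pb.h"
//   #include "tensorflow/core/framework/graph.pb.h"
//   #include "tensorflow/core/lib/core/errors.h"
//
//   // Prunes <full> down to the nodes transitively reachable from the
//   // fetches in <config>; unreachable nodes (e.g. an unknown op nothing
//   // fetches) are dropped before compilation.
//   Status PruneForFetches(const tensorflow::tfcompile::Config& config,
//                          const tensorflow::GraphDef& full,
//                          tensorflow::GraphDef* pruned) {
//     TF_RETURN_IF_ERROR(PruneGraphDefInto(config, full, pruned));
//     // For the deleted configs above (fetches x_y_sum and x_z_sum), every
//     // node on a path to those two Add nodes survives; nothing else does.
//     return Status::OK();
//   }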
string TensorIdToString(const TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
} // namespace tfcompile
} // namespace tensorflow

View File

@@ -16,11 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
#include <unordered_map>
#include "tensorflow/compiler/aot/tfcompile.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -34,23 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// updates references to the fed tensors to refer to the placeholders.
// The existing nodes referenced by the feeds are not removed or modified
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
const Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
GraphDef* out);
// Returns node:port for the given <id>.
string TensorIdToString(const TensorId& id);
} // namespace tfcompile
} // namespace tensorflow