Automated g4 rollback of changelist 161087978
PiperOrigin-RevId: 161111023
This commit is contained in:
parent
cab048ecde
commit
27b341c800
@ -156,7 +156,7 @@ tf_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
# A test of tf_library that includes a graph with an unknown op, but where
|
||||||
# the compilation works because the unknown op is not needed for the fetches.
|
# the compilation works because the the unknown op is not needed for the fetches.
|
||||||
tf_library(
|
tf_library(
|
||||||
name = "test_graph_tfunknownop",
|
name = "test_graph_tfunknownop",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -166,29 +166,6 @@ tf_library(
|
|||||||
tags = ["manual"],
|
tags = ["manual"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
|
||||||
# the compilation works because the op between the unknown op and the
|
|
||||||
# fetches is a feed.
|
|
||||||
tf_library(
|
|
||||||
name = "test_graph_tfunknownop2",
|
|
||||||
testonly = 1,
|
|
||||||
config = "test_graph_tfunknownop2.config.pbtxt",
|
|
||||||
cpp_class = "UnknownOpAddComp",
|
|
||||||
graph = "test_graph_tfunknownop.pbtxt",
|
|
||||||
tags = ["manual"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# A test of tf_library that includes a graph with an unknown op, but where
|
|
||||||
# the compilation works because the unknown op is fed.
|
|
||||||
tf_library(
|
|
||||||
name = "test_graph_tfunknownop3",
|
|
||||||
testonly = 1,
|
|
||||||
config = "test_graph_tfunknownop3.config.pbtxt",
|
|
||||||
cpp_class = "UnknownOpAddComp",
|
|
||||||
graph = "test_graph_tfunknownop.pbtxt",
|
|
||||||
tags = ["manual"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Utility library for benchmark binaries, used by the *_benchmark rules that are
|
# Utility library for benchmark binaries, used by the *_benchmark rules that are
|
||||||
# added by the tfcompile bazel macro.
|
# added by the tfcompile bazel macro.
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -78,51 +78,66 @@ Status DumpGraph(const MainFlags& flags, const string& name,
|
|||||||
return WriteTextProto(Env::Default(), file, graph_def);
|
return WriteTextProto(Env::Default(), file, graph_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string TensorIdToString(const TensorId& id) {
|
||||||
|
return strings::StrCat(id.node_name(), ":", id.output_index());
|
||||||
|
}
|
||||||
|
|
||||||
typedef std::unordered_map<string, Node*> NodeMap;
|
typedef std::unordered_map<string, Node*> NodeMap;
|
||||||
|
|
||||||
// Each feed id identifies the positional output of some node, which may consist
|
// Each feed id identifies the positional output of some node, which may consist
|
||||||
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
|
// of multiple edges. For each feed node, replaces all matching edges so that
|
||||||
// tensor with a placeholder. For each feed tensor, replaces all edges so they
|
// they point from a new _Arg node instead.
|
||||||
// point from a new _Arg node instead.
|
|
||||||
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
|
||||||
const protobuf::RepeatedPtrField<Feed>& feeds,
|
const protobuf::RepeatedPtrField<Feed>& feeds) {
|
||||||
const std::unordered_map<string, string>& feed_remapping) {
|
|
||||||
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
|
||||||
const Feed& feed = feeds[arg_index];
|
const Feed& feed = feeds[arg_index];
|
||||||
// All feeds have been replaced by placeholders.
|
const TensorId& id = feed.id();
|
||||||
const int output_index = 0;
|
auto it = node_map.find(id.node_name());
|
||||||
|
if (it == node_map.end()) {
|
||||||
const auto remap_it = feed_remapping.find(TensorIdToString(feed.id()));
|
return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
|
||||||
auto node_it = node_map.find(remap_it->second);
|
}
|
||||||
const Node* feed_node = node_it->second;
|
const Node* feed_node = it->second;
|
||||||
|
if (id.output_index() >= feed_node->num_outputs()) {
|
||||||
// TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
|
return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
|
||||||
// "_shape" attr if we can determine it. That way the graph will be
|
", output index should be < ",
|
||||||
// initialized with whatever shapes we can infer, while the user can still
|
feed_node->num_outputs());
|
||||||
// explicitly specify or override them.
|
}
|
||||||
|
// TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
|
||||||
|
// if we can determine it. That way the graph will be initialized with
|
||||||
|
// whatever shapes we can infer, while the user can still explicitly specify
|
||||||
|
// or override them.
|
||||||
Node* arg_node = nullptr;
|
Node* arg_node = nullptr;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
|
NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
|
||||||
.Attr("T", BaseType(feed_node->output_type(output_index)))
|
.Attr("T", BaseType(feed_node->output_type(id.output_index())))
|
||||||
.Attr("index", arg_index)
|
.Attr("index", arg_index)
|
||||||
.Attr(kFeedIdAttr, TensorIdToString(feed.id()))
|
.Attr(kFeedIdAttr, TensorIdToString(id))
|
||||||
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
.Attr(kShapeAttr, TensorShape(feed.shape()))
|
||||||
.Attr(kDebugNameAttr, feed.name())
|
.Attr(kDebugNameAttr, feed.name())
|
||||||
.Finalize(graph, &arg_node));
|
.Finalize(graph, &arg_node));
|
||||||
|
|
||||||
// Collects out-edges from the feed node that have a matching edge index;
|
// Collects out-edges from the feed node that have a matching edge index;
|
||||||
// these will be replaced with edges from the arg node instead.
|
// these will be replaced with edges from the arg node instead. Also
|
||||||
|
// replaces all control edges from Placeholder feed nodes; similar code
|
||||||
|
// exists in subgraph::RewriteGraphForExecution.
|
||||||
|
// TODO(toddw): Why only replace control edges from Placeholder?
|
||||||
//
|
//
|
||||||
// We must collect the edges first and process them in a second pass, since
|
// We must collect the edges first and process them in a second pass, since
|
||||||
// removing the edge from the graph invalidates feed_node->out_edges.
|
// removing the edge from the graph invalidates feed_node->out_edges.
|
||||||
std::vector<const Edge*> feed_edges;
|
std::vector<const Edge*> feed_edges;
|
||||||
for (const Edge* edge : feed_node->out_edges()) {
|
for (const Edge* edge : feed_node->out_edges()) {
|
||||||
if (edge->src_output() == output_index) {
|
if (edge->src_output() == id.output_index() ||
|
||||||
|
(edge->src_output() == Graph::kControlSlot &&
|
||||||
|
feed_node->type_string() == "Placeholder")) {
|
||||||
feed_edges.push_back(edge);
|
feed_edges.push_back(edge);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const Edge* edge : feed_edges) {
|
for (const Edge* edge : feed_edges) {
|
||||||
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
if (edge->src_output() == id.output_index()) {
|
||||||
|
graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(edge->src_output(), Graph::kControlSlot);
|
||||||
|
graph->AddControlEdge(arg_node, edge->dst());
|
||||||
|
}
|
||||||
graph->RemoveEdge(edge);
|
graph->RemoveEdge(edge);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -164,16 +179,13 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
|
|||||||
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
|
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
|
||||||
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
|
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
|
||||||
// execution to know the input and output args for the generated function.
|
// execution to know the input and output args for the generated function.
|
||||||
Status RewriteAndPruneGraph(
|
Status RewriteAndPruneGraph(Graph* graph, const Config& config,
|
||||||
Graph* graph, const Config& config,
|
const MainFlags& flags) {
|
||||||
const std::unordered_map<string, string>& feed_remapping,
|
|
||||||
const MainFlags& flags) {
|
|
||||||
NodeMap node_map;
|
NodeMap node_map;
|
||||||
for (Node* n : graph->nodes()) {
|
for (Node* n : graph->nodes()) {
|
||||||
node_map[n->name()] = n;
|
node_map[n->name()] = n;
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
|
||||||
AddArgNodes(graph, node_map, config.feed(), feed_remapping));
|
|
||||||
std::unordered_set<const Node*> retval_nodes;
|
std::unordered_set<const Node*> retval_nodes;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||||
@ -371,28 +383,17 @@ Status InitGraph(const GraphDef& graph_def, const Config& config,
|
|||||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
|
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
|
||||||
std::unique_ptr<Graph> g(new Graph(flib_def));
|
std::unique_ptr<Graph> g(new Graph(flib_def));
|
||||||
|
|
||||||
// Replace references to fed tensors with references to newly added
|
GraphDef copy_def;
|
||||||
// placeholders.
|
|
||||||
GraphDef first_copy_def = graph_def;
|
|
||||||
|
|
||||||
// Maps from name:port of a feed to the name:port of the placeholder to use.
|
|
||||||
std::unordered_map<string, string> feed_remapping;
|
|
||||||
TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
|
|
||||||
&feed_remapping, &first_copy_def));
|
|
||||||
|
|
||||||
// Prune the GraphDef first so that unknown ops that we aren't compiling get
|
// Prune the GraphDef first so that unknown ops that we aren't compiling get
|
||||||
// filtered out.
|
// filtered out.
|
||||||
GraphDef second_copy_def;
|
TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, ©_def));
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
PruneGraphDefInto(config, first_copy_def, &second_copy_def));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
|
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(),
|
||||||
&second_copy_def, *g->op_registry(), 0 /*node_offset*/));
|
0 /*node_offset*/));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
|
|
||||||
second_copy_def, g.get()));
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
|
ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
|
||||||
|
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
|
||||||
*graph = std::move(g);
|
*graph = std::move(g);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -6,12 +6,21 @@ node {
|
|||||||
value {
|
value {
|
||||||
tensor {
|
tensor {
|
||||||
dtype: DT_INT32
|
dtype: DT_INT32
|
||||||
tensor_shape { dim { size: 1 } }
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
int_val: 1
|
int_val: 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr { key : "dtype" value { type: DT_INT32 } }
|
attr {
|
||||||
|
key : "dtype"
|
||||||
|
value {
|
||||||
|
type : DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name : "y_const"
|
name : "y_const"
|
||||||
@ -21,37 +30,56 @@ node {
|
|||||||
value {
|
value {
|
||||||
tensor {
|
tensor {
|
||||||
dtype: DT_INT32
|
dtype: DT_INT32
|
||||||
tensor_shape { dim { size: 1 } }
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
int_val: 2
|
int_val: 2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr { key: "dtype" value { type: DT_INT32 } }
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name : "x_y_sum"
|
name : "x_y_sum"
|
||||||
op : "Add"
|
op : "Add"
|
||||||
input : "x_const"
|
input : "x_const"
|
||||||
input : "y_const"
|
input : "y_const"
|
||||||
attr { key : "T" value { type: DT_INT32 } }
|
attr {
|
||||||
|
key : "T"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name : "z"
|
name : "z"
|
||||||
op : "SomeUnknownOp"
|
op : "SomeUnknownOp"
|
||||||
input : "x_const"
|
input : "x_const"
|
||||||
}
|
attr {
|
||||||
node {
|
key : "T"
|
||||||
name : "z_identity"
|
value {
|
||||||
op : "Identity"
|
type: DT_INT32
|
||||||
input : "z:1"
|
}
|
||||||
attr { key : "T" value { type: DT_INT32 } }
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name : "x_z_sum"
|
name : "x_z_sum"
|
||||||
op : "Add"
|
op : "Add"
|
||||||
input : "x_const"
|
input : "x_const"
|
||||||
input : "z_identity"
|
input : "z"
|
||||||
attr { key : "T" value { type: DT_INT32 } }
|
attr {
|
||||||
|
key : "T"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
versions {
|
versions {
|
||||||
producer: 15
|
producer: 15
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
# Text form of tensorflow.tfcompile.Config proto.
|
|
||||||
feed {
|
|
||||||
id { node_name: "x_const" }
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
feed {
|
|
||||||
id { node_name: "y_const" }
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
feed {
|
|
||||||
id { node_name: "z_identity"}
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fetch {
|
|
||||||
id { node_name: "x_y_sum" }
|
|
||||||
}
|
|
||||||
fetch {
|
|
||||||
id { node_name: "x_z_sum" }
|
|
||||||
}
|
|
@ -1,26 +0,0 @@
|
|||||||
# Text form of tensorflow.tfcompile.Config proto.
|
|
||||||
feed {
|
|
||||||
id { node_name: "x_const" }
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
feed {
|
|
||||||
id { node_name: "y_const" }
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
feed {
|
|
||||||
id { node_name: "z" output_index: 1}
|
|
||||||
shape {
|
|
||||||
dim { size: 1 }
|
|
||||||
}
|
|
||||||
type: DT_INT32
|
|
||||||
}
|
|
||||||
fetch {
|
|
||||||
id { node_name: "x_y_sum" }
|
|
||||||
}
|
|
||||||
fetch {
|
|
||||||
id { node_name: "x_z_sum" }
|
|
||||||
}
|
|
@ -7,7 +7,6 @@ option java_multiple_files = true;
|
|||||||
option java_package = "org.tensorflow.tfcompile";
|
option java_package = "org.tensorflow.tfcompile";
|
||||||
|
|
||||||
import "tensorflow/core/framework/tensor_shape.proto";
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
import "tensorflow/core/framework/types.proto";
|
|
||||||
|
|
||||||
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
|
// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
|
||||||
// index of a particular node in the graph. If the output of the named node
|
// index of a particular node in the graph. If the output of the named node
|
||||||
@ -24,12 +23,6 @@ message Feed {
|
|||||||
TensorId id = 1;
|
TensorId id = 1;
|
||||||
TensorShapeProto shape = 2;
|
TensorShapeProto shape = 2;
|
||||||
string name = 3; // Optional name for generated code.
|
string name = 3; // Optional name for generated code.
|
||||||
|
|
||||||
// Optional data type. This is not normally required, as the graph itself
|
|
||||||
// contains this information. However, if the node being fed is an op that
|
|
||||||
// is not linked into the tfcompile binary, then the type cannot be inferred
|
|
||||||
// from the node; in this case, the type should be set here.
|
|
||||||
DataType type = 4;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Fetch represents a single fetch tensor in the graph, which corresponds to an
|
// Fetch represents a single fetch tensor in the graph, which corresponds to an
|
||||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tfcompile {
|
namespace tfcompile {
|
||||||
@ -120,105 +119,17 @@ Status ValidateConfig(const Config& config) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AddPlaceholdersForFeeds(
|
|
||||||
const Config& config, const OpRegistryInterface* op_registry,
|
|
||||||
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
|
|
||||||
struct PlaceholderInfo {
|
|
||||||
const Feed* feed = nullptr; // point to Feed in <config>.
|
|
||||||
string placeholder_name;
|
|
||||||
DataType data_type = DT_INVALID;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Put each fed tensor into a map by name:port. A map is used for determinism
|
|
||||||
// when creating placeholders (genrules want deterministic output).
|
|
||||||
std::map<string, PlaceholderInfo> placeholder_info;
|
|
||||||
for (int i = 0; i < config.feed_size(); ++i) {
|
|
||||||
const Feed* feed = &config.feed(i);
|
|
||||||
const string name_port = TensorIdToString(feed->id());
|
|
||||||
auto& info = placeholder_info[name_port];
|
|
||||||
info.feed = feed;
|
|
||||||
info.placeholder_name = strings::StrCat(
|
|
||||||
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
|
|
||||||
(*feed_remapping)[name_port] = info.placeholder_name;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify node exists and determine data type.
|
|
||||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
|
||||||
for (int i = 0; i < graph_def->node_size(); ++i) {
|
|
||||||
name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
|
|
||||||
}
|
|
||||||
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
|
|
||||||
PlaceholderInfo& info = it->second;
|
|
||||||
const TensorId& feed_id = info.feed->id();
|
|
||||||
|
|
||||||
// Find the existing node and determine data type.
|
|
||||||
auto node_it = name_to_node.find(feed_id.node_name());
|
|
||||||
if (node_it == name_to_node.end()) {
|
|
||||||
return errors::NotFound("Can't find feed node: ",
|
|
||||||
TensorIdToString(feed_id));
|
|
||||||
}
|
|
||||||
const NodeDef* existing = node_it->second;
|
|
||||||
|
|
||||||
if (info.feed->type() != DT_INVALID) {
|
|
||||||
info.data_type = info.feed->type();
|
|
||||||
} else {
|
|
||||||
// Build the node in order to infer its type.
|
|
||||||
Graph g(op_registry);
|
|
||||||
Status status;
|
|
||||||
Node* feed_node = g.AddNode(*existing, &status);
|
|
||||||
TF_RETURN_IF_ERROR(status);
|
|
||||||
info.data_type =
|
|
||||||
BaseType(feed_node->output_type(info.feed->id().output_index()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create placeholders. Note that we could avoid creating a placeholder for
|
|
||||||
// feeds which are already placeholders, but we omit that to avoid more cases
|
|
||||||
// in this code.
|
|
||||||
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
|
|
||||||
const PlaceholderInfo& info = it->second;
|
|
||||||
NodeDef* d = graph_def->add_node();
|
|
||||||
d->set_name(info.placeholder_name);
|
|
||||||
d->set_op("PlaceholderV2");
|
|
||||||
auto& attr_map = *d->mutable_attr();
|
|
||||||
attr_map["dtype"].set_type(info.data_type);
|
|
||||||
*attr_map["shape"].mutable_shape() = info.feed->shape();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rewrite references to the fed tensors to refer to the placeholder.
|
|
||||||
for (int i = 0; i < graph_def->node_size(); ++i) {
|
|
||||||
NodeDef* node_def = graph_def->mutable_node(i);
|
|
||||||
for (int j = 0; j < node_def->input_size(); ++j) {
|
|
||||||
auto id = ParseTensorName(node_def->input(j));
|
|
||||||
auto it = placeholder_info.find(id.ToString());
|
|
||||||
if (it != placeholder_info.end()) {
|
|
||||||
node_def->set_input(j, it->second.placeholder_name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||||
GraphDef* out) {
|
GraphDef* out) {
|
||||||
*out = in;
|
*out = in;
|
||||||
out->clear_node();
|
out->clear_node();
|
||||||
|
|
||||||
// Tensors needed for feeding.
|
|
||||||
std::set<std::pair<string, int>> feed_tensors;
|
|
||||||
for (const auto& feed_config : config.feed()) {
|
|
||||||
feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
|
|
||||||
feed_config.id().output_index()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Maps node name to reachability.
|
// Maps node name to reachability.
|
||||||
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
|
std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
|
||||||
for (const NodeDef& node : in.node()) {
|
for (const NodeDef& node : in.node()) {
|
||||||
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
|
node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Traverse.
|
|
||||||
std::queue<string> name_queue;
|
std::queue<string> name_queue;
|
||||||
for (int i = 0; i < config.fetch_size(); ++i) {
|
for (int i = 0; i < config.fetch_size(); ++i) {
|
||||||
name_queue.push(config.fetch(i).id().node_name());
|
name_queue.push(config.fetch(i).id().node_name());
|
||||||
@ -238,19 +149,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
|||||||
}
|
}
|
||||||
map_entry.first = true;
|
map_entry.first = true;
|
||||||
|
|
||||||
// Push input nodes of the currently visited node to name_queue.
|
|
||||||
for (const string& in_edge : map_entry.second->input()) {
|
for (const string& in_edge : map_entry.second->input()) {
|
||||||
auto id = ParseTensorName(in_edge);
|
name_queue.push(ParseTensorName(in_edge).first.ToString());
|
||||||
const string node_name = id.first.ToString();
|
|
||||||
if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
|
|
||||||
feed_tensors.end()) {
|
|
||||||
name_queue.push(node_name);
|
|
||||||
} else {
|
|
||||||
// The input tensor is from an edge that is being fed. Therefore,
|
|
||||||
// we skip recursing down that edge, to avoid requiring nodes that
|
|
||||||
// may not be needed (note that the input node may still be added
|
|
||||||
// to name_queue later if one of its output edges is not being fed).
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -265,9 +165,5 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
string TensorIdToString(const TensorId& id) {
|
|
||||||
return strings::StrCat(id.node_name(), ":", id.output_index());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,11 +16,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
||||||
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
#include "tensorflow/compiler/aot/tfcompile.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
|
||||||
@ -34,23 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg);
|
|||||||
// ValidateConfig returns OK iff config is valid.
|
// ValidateConfig returns OK iff config is valid.
|
||||||
Status ValidateConfig(const Config& config);
|
Status ValidateConfig(const Config& config);
|
||||||
|
|
||||||
// Modifies <graph_def> to include placeholders for each fed tensor, and
|
|
||||||
// update references to the fed tensors to refer to the placeholders.
|
|
||||||
// The existing nodes referenced by the feeds are not removed or modified
|
|
||||||
// (except where their input edges are modified by the replacement of other
|
|
||||||
// feeds).
|
|
||||||
Status AddPlaceholdersForFeeds(
|
|
||||||
const Config& config, const OpRegistryInterface* op_registry,
|
|
||||||
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
|
|
||||||
|
|
||||||
// Returns in <out> a copy of <in>, pruned to only include fetches from
|
// Returns in <out> a copy of <in>, pruned to only include fetches from
|
||||||
// <config>.
|
// <config>.
|
||||||
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
Status PruneGraphDefInto(const Config& config, const GraphDef& in,
|
||||||
GraphDef* out);
|
GraphDef* out);
|
||||||
|
|
||||||
// Returns node:port for the given <id>.
|
|
||||||
string TensorIdToString(const TensorId& id);
|
|
||||||
|
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user