Add the option of including Shape, ShapeN, Size and Rank in the standard TensorFlow constant propagation pass, when the inputs to those Ops have sufficiently known static shape.
PiperOrigin-RevId: 163762750
commit 9e7875437f
parent 8b1365bb40
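For illustration, here is a minimal usage sketch (not part of the diff below) of how a caller could hand known output shapes to the constant-folding pass so that downstream Shape, ShapeN, Rank, and Size nodes fold to Const nodes. It mirrors the ConstantFoldingOptions::shape_map field and the ConstantFold call exercised by the tests in this commit; the node pointer `some_node`, the `graph` pointer, and the 1x2 shape are hypothetical placeholders.

  // Sketch: supply statically known output shapes, then run constant folding.
  std::unordered_map<const Node*, std::vector<PartialTensorShape>> shape_map;
  PartialTensorShape known_shape;
  int dims[] = {1, 2};  // hypothetical: some_node is known to produce a 1x2 tensor
  TF_CHECK_OK(PartialTensorShape::MakePartialShape(dims, 2, &known_shape));
  shape_map[some_node].push_back(known_shape);  // one entry per output of some_node

  ConstantFoldingOptions opts;
  opts.shape_map = &shape_map;  // not owned; leave null to skip shape-op folding
  bool was_mutated;
  // The two nullptr arguments are the optional function library runtime and
  // the partition device, as in the tests added by this commit.
  TF_CHECK_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, graph, &was_mutated));
  // Any Shape/ShapeN/Rank/Size node whose inputs had sufficiently known static
  // shapes has now been replaced by a Const node.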
@@ -161,7 +161,7 @@ Status XlaCompiler::CompileFunction(
    opts.set_do_constant_folding(true);
    GraphOptimizer optimizer(opts);
    optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(),
                       /*device=*/nullptr, &graph);
                       /*device=*/nullptr, &graph, /*shape_map=*/nullptr);

    VLOG(1) << "====================================================";
    TF_RETURN_IF_ERROR(
@@ -43,11 +43,182 @@ namespace tensorflow {

namespace {

bool IsConstantFoldable(const Node* n,
                        const std::function<bool(const Node*)>& consider) {
// Test to see if the Op is one that turns into a constant when its
// inputs' shapes are known.
bool IsShapeOp(const Node* n) {
  const auto& ts = n->type_string();
  return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
}

// Reads the partially-known shape of each of n's inputs from shape_map, and
// stores it to input_shapes. Returns false if any input does not have a shape
// in shape_map.
bool ReadPartialShapesFromShapeMap(
    const Node* n,
    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
        shape_map,
    std::vector<PartialTensorShape>* input_shapes) {
  CHECK(shape_map != nullptr);
  for (const Edge* in : n->in_edges()) {
    // Don't need to check if incoming control edges have known shapes.
    if (in->IsControlEdge()) continue;
    if (shape_map->count(in->src()) == 0) {
      // One of n's inputs doesn't have known shapes, so don't replace n.
      return false;
    }
    const auto& known_shape = shape_map->at(in->src());
    CHECK_GT(known_shape.size(), in->src_output());
    input_shapes->push_back(known_shape[in->src_output()]);
  }
  return true;
}

// If all of n's inputs have fully-defined shapes, inserts those shapes as a
// vector of Tensors in the shape_replacement_map.
bool MaybeReplaceShapeOrShapeNOp(
    const Node* n, const std::vector<PartialTensorShape>& input_shapes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  std::vector<Tensor> defined_shape;
  for (const auto& shape : input_shapes) {
    if (!shape.IsFullyDefined()) {
      return false;
    }
    const int rank = shape.dims();
    DataType op_type = n->output_type(0);
    Tensor t(op_type, TensorShape({rank}));
    if (op_type == DT_INT64) {
      auto vec = t.vec<int64>();
      for (int i = 0; i < rank; ++i) {
        vec(i) = shape.dim_size(i);
      }
    } else {
      CHECK(op_type == DT_INT32);
      auto vec = t.vec<int32>();
      for (int i = 0; i < rank; ++i) {
        if (shape.dim_size(i) > INT_MAX) {
          VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
                  << " of " << shape.dim_size(i) << " but type INT32 "
                  << " so not replacing as constant: this will trigger a "
                     "runtime error later.";
          return false;
        }
        vec(i) = static_cast<int32>(shape.dim_size(i));
      }
    }
    defined_shape.push_back(t);
  }
  // All the inputs had known shapes so we can replace the node by constants
  // later in the rewrite.
  shape_replacement_map->insert({n, defined_shape});
  return true;
}

// If n's input has defined rank, inserts that rank as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceRankOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (input_shapes[0].unknown_rank()) {
    return false;
  }
  Tensor t(DT_INT32, TensorShape({}));
  t.scalar<int32>()() = input_shapes[0].dims();
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n's input has defined size, inserts that size as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceSizeOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (!input_shapes[0].IsFullyDefined()) {
    return false;
  }
  DataType op_type = n->output_type(0);
  Tensor t(op_type, TensorShape({}));
  int64 size = input_shapes[0].num_elements();
  if (op_type == DT_INT64) {
    t.scalar<int64>()() = size;
  } else {
    CHECK(op_type == DT_INT32);
    if (size > INT_MAX) {
      VLOG(1) << "Node " << n->name() << " has input shape size " << size
              << " but type INT32 "
              << " so not replacing as constant: this will trigger a runtime "
                 "error later.";
      return false;
    }
    t.scalar<int32>()() = static_cast<int32>(size);
  }
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
// shapes specified in shape_map, then adds to shape_replacement_map a mapping
// from n to a vector of Tensors, where Tensor k is the (statically known) value
// on n's kth output edge. shape_replacement_map has an entry for n iff
// MaybeReplaceShapeOp returns true, so it's valid to use
// shape_replacement_map->count(n) as a test to see if n is a shape op that can
// be replaced.
bool MaybeReplaceShapeOp(
    const Node* n,
    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
        shape_map,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (shape_map == nullptr || !IsShapeOp(n)) {
    return false;
  }
  // input_shapes will contain the shapes of each of n's inputs.
  std::vector<PartialTensorShape> input_shapes;
  if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
    return false;
  }
  const auto& ts = n->type_string();
  if (ts == "Shape" || ts == "ShapeN") {
    if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else if (ts == "Rank") {
    if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else {
    CHECK_EQ(ts, "Size");
    if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  }
  return true;
}

// Returns true if n can be evaluated as constant. shape_map maps from
// nodes to the partially-known shapes of their outputs. consider if
// non-null returns a bool indicating whether a given (non-Const,
// non-Shape) node is eligible to be
// constant-propagated. shape_replacement_map is filled in with a
// vector of constant output tensors for constant-foldable shape nodes
// (Shape, ShapeN, Size, or Rank).
bool IsConstantFoldable(
    const Node* n,
    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
        shape_map,
    const std::function<bool(const Node*)>& consider,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (n->IsConstant()) {
    return true;
  }
  if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
    return true;
  }
  if (n->op_def().is_stateful()) {
    return false;
  }
@@ -82,56 +253,81 @@ bool IsConstantFoldable(const Node* n,
  return true;
}

// If n is eligible for constant-folding, adds it to nodes, and places its
// control dependencies and those transitively of its constant-foldable inputs
// into constant_control_deps. If n is a constant-foldable shape node (Shape,
// ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
void ConsiderConstantFoldableNode(
    Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
    bool* internal_node_inserted) {
  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
                         shape_replacement_map)) {
    // A node is constant provided all of its non-control incoming Tensors come
    // from constant nodes, or it's a shape Op with statically known inputs in
    // which case it is placed in shape_replacement_map.
    //
    // We allow control dependencies from non-constant nodes to constant nodes,
    // but to preserve the graph structure we must transfer the control
    // dependency onto any constant replacement.
    bool all_parents_constant = true;
    for (const Edge* in : n->in_edges()) {
      // Allows non-constant -> constant control edges.
      if (!in->IsControlEdge() &&
          constant_control_deps->count(in->src()) == 0) {
        all_parents_constant = false;
        break;
      }
    }
    if (all_parents_constant || shape_replacement_map->count(n) != 0) {
      gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
      for (const Edge* e : n->in_edges()) {
        if (constant_control_deps->count(e->src()) == 0) {
          // This branch is taken if the incoming edge is a control dependency,
          // in which case we want to add it to the dependencies being
          // accumulated for this node, or the incoming edge is not
          // constant. The latter may happen when n is a shape node and the
          // source has known shape. In that case add a control dependency from
          // the source node, since there was previously a data dependency and
          // we want to preserve sequencing constraints.
          if (!e->src()->IsSource()) {
            control_deps.insert(e->src());
          }
        } else {
          // If the parent has been accumulating control dependencies, add all
          // of its transitive control deps.
          const gtl::FlatSet<Node*>& parent_deps =
              (*constant_control_deps)[e->src()];
          control_deps.insert(parent_deps.begin(), parent_deps.end());
        }
      }
      nodes->push_back(n);
      if (!n->IsConstant()) {
        *internal_node_inserted = true;
      }
    }
  }
}

// Returns the constant foldable nodes in `nodes` in topological order.
// Populates `constant_control_deps` with the non-constant control dependencies
// of each constant node.
void FindConstantFoldableNodes(
    const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>*
        constant_control_deps) {
    const Graph* graph, const ConstantFoldingOptions& opts,
    std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  bool internal_node_inserted = false;
  // Walk the nodes in data flow order
  ReverseDFS(
      *graph, nullptr,
      [nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) {
        if (IsConstantFoldable(n, opts.consider)) {
          // A node is constant provided all of its non-control
          // incoming Tensors come from constant nodes.
          //
          // We allow control dependencies from non-constant nodes to constant
          // nodes, but to preserve the graph structure we must transfer the
          // control dependency onto any constant replacement.
          bool all_parents_constant = true;
          for (const Edge* in : n->in_edges()) {
            // Allows non-constant -> constant control edges.
            if (!in->IsControlEdge() &&
                constant_control_deps->count(in->src()) == 0) {
              all_parents_constant = false;
              break;
            }
          }
          if (all_parents_constant) {
            gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
            for (const Edge* e : n->in_edges()) {
              if (constant_control_deps->count(e->src()) == 0) {
                if (!e->src()->IsSource()) {
                  control_deps.insert(e->src());
                }
              } else {
                // If the parent is constant, add all of its transitive control
                // deps.
                const gtl::FlatSet<Node*>& parent_deps =
                    (*constant_control_deps)[e->src()];
                control_deps.insert(parent_deps.begin(), parent_deps.end());
              }
            }
            nodes->push_back(n);
            if (!n->IsConstant()) {
              internal_node_inserted = true;
            }
          }
        }
      });
  // Walk the nodes in data flow order.
  ReverseDFS(*graph, nullptr,
             [nodes, constant_control_deps, shape_replacement_map,
              &internal_node_inserted, &opts](Node* n) {
               ConsiderConstantFoldableNode(
                   n, opts, nodes, constant_control_deps, shape_replacement_map,
                   &internal_node_inserted);
             });
  // If we have inserted just leaf level nodes, then there is nothing to fold.
  if (!internal_node_inserted) {
    nodes->clear();
@@ -141,31 +337,93 @@ void FindConstantFoldableNodes(

typedef std::pair<Node*, int> NodeAndOutput;

int64 UniqueConstantId() {
  static std::atomic_int_fast64_t id;
  return id.fetch_add(1);
}

// Adds n to constant_graph which is being built up for subsequent evaluation of
// constant propagation. node_map is the mapping of nodes in the original graph
// to nodes in the constant graph. The value of an entry in node_map is a vector
// of nodes because a ShapeN node in the original graph is replaced by a vector
// of Constant nodes in the constant graph.
void AddNodeToConstantGraph(
    Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
    Graph* constant_graph) {
  std::vector<Node*>& added = (*node_map)[n];
  added.push_back(constant_graph->CopyNode(n));
  for (const Edge* in_edge : n->in_edges()) {
    // Don't copy control edges to the constant graph.
    if (!in_edge->IsControlEdge()) {
      Node* in = in_edge->src();
      auto it = node_map->find(in);
      CHECK(it != node_map->end())
          << n->DebugString() << " <-" << in->DebugString();
      if (it->second.size() == 1) {
        constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
                                in_edge->dst_input());
      } else {
        // The original source node had multiple outputs and was replaced by a
        // vector of constants, so the edge comes from the 0th output of the kth
        // added constant, rather than the kth output of the added node as in
        // the standard case above.
        constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
                                in_edge->dst_input());
      }
    }
  }
}

// Replaces constant-foldable shape node n by a vector of constants in
// constant_graph, which is being built up for subsequent evaluation of constant
// propagation. node_map is the mapping of nodes in the original graph to nodes
// in the constant graph. The value of an entry in node_map is a vector of nodes
// because a ShapeN node in the original graph is replaced by a vector of
// Constant nodes in the constant graph.
void AddShapeNodeToConstantGraph(
    Node* n,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::unordered_map<Node*, std::vector<Node*>>* node_map,
    Graph* constant_graph) {
  std::vector<Node*>& added = (*node_map)[n];
  const string& node_name = n->name();
  for (const Tensor& t : shape_replacement_map.at(n)) {
    auto builder =
        NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name),
                                       "__cf__", UniqueConstantId()),
                       "Const")
            .Attr("dtype", t.dtype())
            .Attr("value", t);
    NodeDef def;
    CHECK(builder.Finalize(&def).ok());
    Node* constant_node;
    CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
    added.push_back(constant_node);
  }
  // Don't copy incoming edges to shape nodes that are being replaced.
}

// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
// 'nodes', then 'tensors_to_fetch' will contain the mapping from the
// corresponding copy of 'n' and the edge number in 'g' to 'n'.
Graph* GetConstantGraph(const Graph* orig_graph,
                        const std::vector<Node*>& nodes,
                        std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
Graph* GetConstantGraph(
    const Graph* orig_graph, const std::vector<Node*>& nodes,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
  Graph* constant_graph = new Graph(orig_graph->op_registry());
  std::unordered_map<Node*, Node*> node_map;
  node_map[orig_graph->source_node()] = constant_graph->source_node();
  node_map[orig_graph->sink_node()] = constant_graph->sink_node();
  std::unordered_map<Node*, std::vector<Node*>> node_map;
  node_map[orig_graph->source_node()] = {constant_graph->source_node()};
  node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
  for (Node* n : nodes) {
    Node* added = constant_graph->CopyNode(n);
    node_map[n] = added;
    for (const Edge* in_edge : n->in_edges()) {
      // Don't copy control edges to the constant graph.
      if (!in_edge->IsControlEdge()) {
        Node* in = in_edge->src();
        auto it = node_map.find(in);
        CHECK(it != node_map.end())
            << n->DebugString() << " <-" << in->DebugString();
        constant_graph->AddEdge(it->second, in_edge->src_output(), added,
                                in_edge->dst_input());
      }
    if (shape_replacement_map.count(n) == 0) {
      AddNodeToConstantGraph(n, &node_map, constant_graph);
    } else {
      AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
                                  constant_graph);
    }
  }

@@ -173,8 +431,19 @@ Graph* GetConstantGraph(const Graph* orig_graph,
  for (const Edge* out_edge : added_nodes.first->out_edges()) {
    if (node_map.count(out_edge->dst()) == 0) {
      if (out_edge->IsControlEdge()) continue;
      tensors_to_fetch->insert(
          {{added_nodes.second, out_edge->src_output()}, added_nodes.first});
      if (added_nodes.second.size() == 1) {
        tensors_to_fetch->insert(
            {{added_nodes.second[0], out_edge->src_output()},
             added_nodes.first});
      } else {
        // The node had multiple outputs and was replaced by a
        // vector of constants, so the NodeAndOutput is the 0th
        // output of the kth added constant, rather than the kth
        // output of the added node as in the standard case above.
        tensors_to_fetch->insert(
            {{added_nodes.second[out_edge->src_output()], 0},
             added_nodes.first});
      }
    }
  }
}
@@ -182,11 +451,6 @@ Graph* GetConstantGraph(const Graph* orig_graph,
  return constant_graph;
}

int64 UniqueConstantId() {
  static std::atomic_int_fast64_t id;
  return id.fetch_add(1);
}

// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null
// is the device where the graph executes. Returns true if the
@@ -291,8 +555,9 @@ Status ConstantFold(const ConstantFoldingOptions& opts,

  std::vector<Node*> constant_foldable_nodes;
  std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
  std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
  FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
                            &constant_control_deps);
                            &constant_control_deps, &shape_replacement_map);
  if (constant_foldable_nodes.empty()) {
    VLOG(1) << "No constant foldable nodes found";
    *was_mutated = false;
@@ -302,7 +567,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,

  std::map<NodeAndOutput, Node*> tensors_to_fetch;
  std::unique_ptr<Graph> constant_graph(
      GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch));
      GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
                       &tensors_to_fetch));
  DumpGraph("Constant graph", constant_graph.get());

  if (tensors_to_fetch.empty()) {
@@ -337,7 +603,6 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
  if (!s.ok()) {
    VLOG(1) << "Could not fetch constants: " << s;
    *was_mutated = false;
    // This is not an error, so return the status as OK.
    return s;
  }

@@ -29,6 +29,11 @@ struct ConstantFoldingOptions {
  // If "consider" is not a nullptr, then only constant fold a node "n" if
  // consider(n) returns true.
  std::function<bool(const Node*)> consider = nullptr;
  // If shape_map is not a nullptr, it is a map from node n to a
  // vector of the (potentially partially-known) shapes of its
  // outputs.
  const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
      shape_map;  // not owned
};

// Perform constant folding optimization on "graph".
@@ -363,6 +363,197 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
  }
}

TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
    auto rank = ops::Rank(s.WithOpName("rank"), recv0);
    auto size = ops::Size(s.WithOpName("size"), recv1);
    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
                            0, "receiver");
    auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
    auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
    auto add1 = ops::Add(s, shape, shape_n[0]);
    auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {1, 2};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  int r1_dims[] = {2, 3, 4};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
  map[recv0].push_back(ps0);
  map[recv1].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = NodeNameIndex(g);
  Node* recv2 = index.at("recv2");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {26}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* cf1 = *(send1->in_nodes().begin());
  ExpectNodeEqual<int>(cf1, {2, 4}, {2});

  ASSERT_EQ(1, send2->num_inputs());
  Node* cf2 = *(send2->in_nodes().begin());
  ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});

  ASSERT_EQ(3, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
        << e->src()->name();
  }

  ASSERT_EQ(2, cf1->in_edges().size());
  for (const Edge* e : cf1->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }

  ASSERT_EQ(2, cf2->in_edges().size());
  for (const Edge* e : cf2->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
  }
}
TEST_F(ConstantFoldingTest, PartialShape) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
                              "sender", 0, "receiver");
    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
                              "sender", 0, "receiver");
    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
    auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
    auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
    auto size = ops::Size(s.WithOpName("size"), recv0);
    auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
                            "receiver");
    auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
                            "receiver");
    auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
                            "receiver");
    auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
  Node* recv0 = orig_index.at("recv0");
  Node* recv1 = orig_index.at("recv1");
  PartialTensorShape ps0;
  int r0_dims[] = {-1, -1};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
  PartialTensorShape ps1;
  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
  map[recv0].push_back(ps0);
  map[recv1].push_back(ps1);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = NodeNameIndex(g);
  Node* shape = index.at("shape");
  Node* size = index.at("size");
  Node* rank1 = index.at("rank1");
  Node* send0 = index.at("send0");
  Node* send1 = index.at("send1");
  Node* send2 = index.at("send2");
  Node* send3 = index.at("send3");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {2}, {});

  ASSERT_EQ(1, send1->num_inputs());
  Node* ncf1 = *(send1->in_nodes().begin());
  EXPECT_EQ(ncf1, shape);

  ASSERT_EQ(1, send2->num_inputs());
  Node* ncf2 = *(send2->in_nodes().begin());
  EXPECT_EQ(ncf2, size);

  ASSERT_EQ(1, send3->num_inputs());
  Node* ncf3 = *(send3->in_nodes().begin());
  EXPECT_EQ(ncf3, rank1);
}
TEST_F(ConstantFoldingTest, ConstShapeKnown) {
  Graph g(OpRegistry::Global());
  {
    Scope s = Scope::NewRootScope();
    auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
                            0, "receiver");
    auto c0 =
        ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
    auto rank = ops::Rank(s.WithOpName("rank"), c0);
    auto add0 = ops::Add(s, rank, rank);
    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
                            "receiver");
    TF_ASSERT_OK(s.ToGraph(&g));
  }
  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
  Node* c0 = orig_index.at("c0");
  PartialTensorShape ps0;
  int c0_dims[] = {};
  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
  map[c0].push_back(ps0);
  ConstantFoldingOptions opts;
  opts.shape_map = &map;
  bool was_mutated;
  TF_EXPECT_OK(
      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
  EXPECT_TRUE(was_mutated);

  std::unordered_map<string, Node*> index = NodeNameIndex(g);
  Node* recv0 = index.at("recv0");
  Node* send0 = index.at("send0");

  ASSERT_EQ(1, send0->num_inputs());
  Node* cf0 = *(send0->in_nodes().begin());
  ExpectNodeEqual<int>(cf0, {0}, {});

  ASSERT_EQ(1, cf0->in_edges().size());
  for (const Edge* e : cf0->in_edges()) {
    EXPECT_TRUE(e->IsControlEdge());
    EXPECT_TRUE(e->src() == recv0) << e->src()->name();
  }
}

namespace {

const char kTestMemRegionName[] = "test://test";
@@ -1194,7 +1194,8 @@ Status DirectSession::GetOrCreateExecutors(
    };
    params.node_outputs_cb = node_outputs_callback_;

    optimizer.Optimize(lib, options_.env, device, &iter->second);
    optimizer.Optimize(lib, options_.env, device, &iter->second,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
    if (!options.debug_options.debug_tensor_watch_opts().empty()) {
@@ -461,7 +461,7 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(lib, lib->env(), lib->device(), g);
  optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
}

Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
@@ -470,7 +470,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
  std::unique_ptr<Graph> g(new Graph(lib_def_));
  CopyGraph(*fbody->graph, g.get());

  optimizer_.Optimize(this, env(), device(), &g);
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));
@@ -33,8 +33,11 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {

GraphOptimizer::~GraphOptimizer() {}

void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                              Device* device, std::unique_ptr<Graph>* graph) {
void GraphOptimizer::Optimize(
    FunctionLibraryRuntime* runtime, Env* env, Device* device,
    std::unique_ptr<Graph>* graph,
    const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
        shape_map) {
  Graph* g = graph->get();
  DumpGraph("Initial", g);

@@ -57,6 +60,7 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,

  if (opts_.do_constant_folding()) {
    ConstantFoldingOptions cf_opts;
    cf_opts.shape_map = shape_map;
    bool was_mutated;
    ConstantFold(cf_opts, runtime, env, device, g, &was_mutated)
        .IgnoreError();
@@ -30,12 +30,17 @@ class GraphOptimizer {
  ~GraphOptimizer();

  // Applies optimization passes specified in 'opts' to 'graph'.
  // Maybe replace *graph with a new graph object.
  // 'device' is device on which the 'graph' will execute. It's passed to the
  // optimizers so that they can respect constraints if any, that should be
  // respected.
  void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device,
                std::unique_ptr<Graph>* graph);
  // Maybe replace *graph with a new graph object. 'device' is device
  // on which the 'graph' will execute. It's passed to the optimizers
  // so that they can respect constraints if any, that should be
  // respected. If shape_map is not null it maps from nodes in graph
  // to partially-known shapes of their outputs, and may be used,
  // e.g., in the constant folding pass.
  void Optimize(
      FunctionLibraryRuntime* runtime, Env* env, Device* device,
      std::unique_ptr<Graph>* graph,
      const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
          shape_map);

 private:
  OptimizerOptions opts_;
@@ -244,7 +244,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
      }
    };

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
@@ -127,7 +127,8 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,

  // Optimize the graph.
  GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flib.get(), env, devices[0], &graphptr);
  optimizer.Optimize(flib.get(), env, devices[0], &graphptr,
                     /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  return Status::OK();