Fixes for outside compilation:

1. If a Const node is in outside compilation but it also has non-outside-compilation users, copy it as a non-outside-compilation node, because in some cases those users require a compile-time constant input.
2. In shape inference, replace a VariableShape node with a Const node when the variable's shape is already known. Otherwise the VariableShape node might become an input to the TPU computation, and TPU computation requires that input to be a compile-time constant.

PiperOrigin-RevId: 288049332
Change-Id: I1bff62ef044781ca04d1a453e24cd0efa7d532b3
This commit is contained in:
parent
e51a2086d5
commit
e450050b69
|
@ -500,6 +500,7 @@ cc_library(
|
|||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -2130,6 +2130,53 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// For every Const node marked for outside compilation that also feeds at
// least one non-outside-compilation consumer, create a duplicate Const
// without the outside compilation attribute and reroute those consumers to
// the duplicate. Some such consumers require compile-time-constant inputs,
// which must come from the compiled (non-outside-compilation) part of the
// graph rather than from an outside compilation cluster.
//
// Returns an error only if constructing the duplicate node fails; otherwise
// mutates `g` in place and returns OK.
Status CopyOutsideCompilationConstNodes(
    Graph* g, const string& outside_compilation_attr_name) {
  for (Node* const_node : g->op_nodes()) {
    // Only Const nodes carrying the outside compilation attribute matter.
    if (!const_node->IsConstant() ||
        !HasNodeAttr(const_node->def(), outside_compilation_attr_name)) {
      continue;
    }

    // Snapshot the output edges up front; the edge set is mutated below.
    std::vector<const Edge*> output_edges(const_node->out_edges().begin(),
                                          const_node->out_edges().end());

    // A data edge whose destination is NOT in outside compilation is the
    // case we need to fix up.
    auto is_non_oc_data_edge = [&](const Edge* edge) {
      return !edge->IsControlEdge() &&
             !HasNodeAttr(edge->dst()->def(), outside_compilation_attr_name);
    };

    // Skip Const nodes consumed only inside outside compilation clusters.
    bool feeds_non_oc_consumer = false;
    for (const Edge* edge : output_edges) {
      if (is_non_oc_data_edge(edge)) {
        feeds_non_oc_consumer = true;
        break;
      }
    }
    if (!feeds_non_oc_consumer) {
      continue;
    }

    // Build the duplicate: same def, fresh name, attribute stripped so the
    // copy is treated as part of the compiled computation.
    NodeDef dup_def = const_node->def();
    dup_def.set_name(g->NewName(const_node->name()));
    dup_def.mutable_attr()->erase(outside_compilation_attr_name);
    Status status;
    Node* dup_node = g->AddNode(dup_def, &status);
    TF_RETURN_IF_ERROR(status);

    // Mirror the original's incoming control dependencies onto the copy so
    // it runs under the same ordering constraints.
    for (const Edge* edge : const_node->in_edges()) {
      if (edge->IsControlEdge()) {
        g->AddControlEdge(edge->src(), dup_node);
      }
    }

    // Reroute every non-outside-compilation data consumer to the copy.
    for (const Edge* edge : output_edges) {
      if (!is_non_oc_data_edge(edge)) {
        continue;
      }
      Node* consumer = edge->dst();
      const int consumer_input = edge->dst_input();
      g->RemoveEdge(edge);
      g->AddEdge(dup_node, 0, consumer, consumer_input);
    }
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RewriteOutsideCompilationSubgraphFn::operator()(
|
||||
|
@ -2279,6 +2326,10 @@ Status ExtractOutsideCompilationForFunction(
|
|||
std::vector<string> outside_compilation_host_graphs;
|
||||
std::vector<string> shape_inference_graphs_to_rewrite;
|
||||
if (*has_outside_compilation) {
|
||||
// Copy outside compilation Const nodes with non outside compilation users.
|
||||
TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
|
||||
fbody->graph, outside_compilation_attr_name));
|
||||
|
||||
// Find dependencies between outside compilation clusters.
|
||||
TF_ASSIGN_OR_RETURN(auto cluster_deps,
|
||||
OutsideCompilationClusterDependencies(
|
||||
|
|
|
@ -17,7 +17,10 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
|
@ -39,7 +42,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
|
|||
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
|
||||
}
|
||||
|
||||
Status PropagateShapes(const Graph& graph,
|
||||
Status PropagateShapes(Graph* graph,
|
||||
const std::map<int, InferredShape>& arg_shapes,
|
||||
const std::vector<BackEdgeHelper::BackEdge>& back_edges,
|
||||
ShapeRefiner* shape_refiner) {
|
||||
|
@ -54,7 +57,7 @@ Status PropagateShapes(const Graph& graph,
|
|||
// shapes.
|
||||
// TODO(phawkins): handle cyclic graphs.
|
||||
std::vector<Node*> order;
|
||||
GetReversePostOrder(graph, &order);
|
||||
GetReversePostOrder(*graph, &order);
|
||||
|
||||
for (Node* n : order) {
|
||||
// Ignore the status returned by the shape_refiner. We want the best effort
|
||||
|
@ -99,6 +102,67 @@ Status PropagateShapes(const Graph& graph,
|
|||
}
|
||||
}
|
||||
|
||||
// Sometimes we have VariableShape nodes in while loop (after Enter nodes).
|
||||
// They won't be constant-folded because TensorFlow constant folding does
|
||||
// not handle Enter nodes (and thus does not handle any nodes after Enter
|
||||
// nodes). We try to replace such VariableShape nodes with Const nodes here.
|
||||
if (n->type_string() == "VariableShape") {
|
||||
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
|
||||
auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
|
||||
if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
|
||||
shape_inference::ShapeHandle handle =
|
||||
handle_shapes_and_types->at(0).shape;
|
||||
TensorShapeProto shape_proto;
|
||||
context->ShapeHandleToProto(handle, &shape_proto);
|
||||
if (!shape_proto.unknown_rank()) {
|
||||
NodeDef const_def;
|
||||
const_def.set_op("Const");
|
||||
Node* var_node;
|
||||
TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
|
||||
const_def.set_name(
|
||||
graph->NewName(absl::StrCat("var_shape_", var_node->name())));
|
||||
DataType dtype = n->output_type(0);
|
||||
AddNodeAttr("dtype", dtype, &const_def);
|
||||
TensorProto value;
|
||||
value.set_dtype(dtype);
|
||||
value.mutable_tensor_shape()->add_dim()->set_size(
|
||||
shape_proto.dim_size());
|
||||
for (const auto& dim : shape_proto.dim()) {
|
||||
if (dtype == DT_INT32) {
|
||||
value.add_int_val(dim.size());
|
||||
} else {
|
||||
value.add_int64_val(dim.size());
|
||||
}
|
||||
}
|
||||
AddNodeAttr("value", value, &const_def);
|
||||
for (auto const& attr : n->attrs()) {
|
||||
if (*attr.first.begin() == '_') {
|
||||
AddNodeAttr(attr.first, attr.second, &const_def);
|
||||
}
|
||||
}
|
||||
|
||||
Status s;
|
||||
Node* const_node = graph->AddNode(const_def, &s);
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
|
||||
graph->AddControlEdge(var_node, const_node);
|
||||
std::vector<const Edge*> out_edges(n->out_edges().begin(),
|
||||
n->out_edges().end());
|
||||
for (const Edge* e : out_edges) {
|
||||
if (e->IsControlEdge()) {
|
||||
graph->AddControlEdge(const_node, e->dst());
|
||||
graph->RemoveEdge(e);
|
||||
} else {
|
||||
Node* dst = e->dst();
|
||||
int dst_input = e->dst_input();
|
||||
graph->RemoveEdge(e);
|
||||
graph->AddEdge(const_node, 0, dst, dst_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge node causes a loop so we remove NextIteration->Merge edge before
|
||||
// performing shape inference. But removing those edges also prevents us
|
||||
// from inferring output shape for Merge node (we need shapes for all its
|
||||
|
@ -196,7 +260,7 @@ Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
|
|||
// the shape inference is complete.
|
||||
BackEdgeHelper back_edge;
|
||||
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes,
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
|
||||
back_edge.RemovedEdges(), &shape_refiner));
|
||||
TF_RETURN_IF_ERROR(back_edge.Replace());
|
||||
|
||||
|
|
Loading…
Reference in New Issue