/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/shape_inference.h"
|
|
|
|
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
|
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
|
#include "tensorflow/core/framework/node_def_util.h"
|
|
#include "tensorflow/core/framework/shape_inference.h"
|
|
#include "tensorflow/core/framework/tensor.pb.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
|
#include "tensorflow/core/graph/algorithm.h"
|
|
#include "tensorflow/core/util/dump_graph.h"
|
|
|
|
namespace tensorflow {

namespace {

// Converts a shape inference handle to a PartialTensorShape.
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
                                const shape_inference::ShapeHandle& handle,
                                PartialTensorShape* shape) {
  // The default is already unknown.
  if (!context->RankKnown(handle)) return Status::OK();

  std::vector<int64> dims(context->Rank(handle));
  for (int32 i = 0, end = dims.size(); i < end; ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}

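// Infers shapes for every node in `graph` in reverse post-order, seeding the
// refiner with the shapes in `arg_shapes` (keyed by the _Arg "index"
// attribute). `back_edges` lists the NextIteration->Merge edges that were
// removed from the graph before inference, so loop-invariant resource shapes
// can still be propagated through the corresponding Merge nodes.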
Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
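  // Map each Merge node to the NextIteration node whose back edge into it was
  // removed, so that loop-invariant resource Merges can be recognized below.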
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

  for (Node* n : order) {
    // Ignore the status returned by the shape_refiner. We want best-effort
    // shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
      }
    }

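    // Seed _Arg nodes with the caller-provided shapes; for resource args this
    // also records the shape and type of the value the handle points to.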
    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes in a while loop (after Enter
    // nodes). They won't be constant-folded because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // after Enter nodes). We try to replace such VariableShape nodes with
    // Const nodes here.
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          // Build a 1-D Const holding the dimension sizes of the variable's
          // shape.
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
          TensorProto value;
          value.set_dtype(dtype);
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
          // Carry over internal ("_"-prefixed) attributes.
          for (auto const& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          Status s;
          Node* const_node = graph->AddNode(const_def, &s);
          TF_RETURN_IF_ERROR(s);

          // Keep ordering with a control edge from the variable, then rewire
          // all consumers of the VariableShape node to the new Const node.
          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // A Merge node causes a loop, so we remove the NextIteration->Merge edge
    // before performing shape inference. But removing those edges also
    // prevents us from inferring the output shape of the Merge node (we would
    // need shapes for all of its inputs).
    // For a loop-invariant resource input's Merge node, we set the output
    // resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check whether this is a loop-invariant input's Merge node, by
      // checking whether the corresponding NextIteration node comes directly
      // from a Switch node.
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
        const Node *next_iter = iter->second, *node = next_iter;
        // Walk backwards through any Identity nodes feeding NextIteration.
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
          // Copy the resource shape and type from whichever input is still
          // present (the Enter side, since the back edge was removed).
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return Status::OK();
}

// Stores the shapes of the output tensors in a map, keyed by node name.
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
                         GraphShapeInfo* shape_info) {
  for (const Node* node : graph.nodes()) {
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
    if (!context) continue;

    auto& outputs = (*shape_info)[node->name()];
    outputs.resize(context->num_outputs());
    for (int i = 0; i < context->num_outputs(); ++i) {
      auto& output = outputs[i];
      TF_RETURN_IF_ERROR(
          ShapeHandleToTensorShape(context, context->output(i), &output.shape));

      const auto* handle_shapes_and_types =
          context->output_handle_shapes_and_types(i);
      if (handle_shapes_and_types != nullptr) {
        if (handle_shapes_and_types->size() == 1) {
          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
              context, (*handle_shapes_and_types)[0].shape,
              &output.handle_shape));
          output.handle_type = (*handle_shapes_and_types)[0].dtype;
        } else {
          // Otherwise, it may be a resource like a Queue, which can have
          // multiple shapes and types represented by a single handle.
        }
      }
      VLOG(4) << node->name() << " output " << i << " shape "
              << output.shape.DebugString() << " handle_type "
              << DataTypeString(output.handle_type) << " handle_shape "
              << output.handle_shape.DebugString();
    }
  }
  return Status::OK();
}

}  // namespace

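// Runs shape inference over `graph`, seeding _Arg nodes from `arg_shapes`,
// and fills `shape_info` with the inferred (possibly partial) output shapes
// of every node.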
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
                   GraphShapeInfo* shape_info) {
  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  // TODO(dlibenzi): Verify whether it is worth trying to infer shapes within
  // functions. Some functions can be called at multiple locations with
  // different shapes, which will trigger a shape inference based on the
  // arguments passed at the first call.
  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);

  // ShapeRefiner requires that all inputs of a node are present when
  // ShapeRefiner::AddNode is called. To get at least some shape information in
  // loops, we temporarily remove loop backedges and add them back again after
  // the shape inference is complete.
  BackEdgeHelper back_edge;
  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
  TF_RETURN_IF_ERROR(PropagateShapes(graph, arg_shapes,
                                     back_edge.RemovedEdges(), &shape_refiner));
  TF_RETURN_IF_ERROR(back_edge.Replace());

  // Currently information does not flow "backward" from consumers to producers
  // in the shape inference, but we consume the shapes in a second pass in case
  // backward information flow is added in the future.
  return StoreOutputShapes(*graph, shape_refiner, shape_info);
}

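// Merges two InferredShapes, taking the most specific of each pair of shapes
// and failing if the resource handle types disagree.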
xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
                                                 const InferredShape& b) {
  InferredShape result;
  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));

  if (a.handle_type == DT_INVALID) {
    result.handle_type = b.handle_type;
  } else if (b.handle_type == DT_INVALID) {
    result.handle_type = a.handle_type;
  } else if (a.handle_type == b.handle_type) {
    result.handle_type = a.handle_type;
  } else {
    return errors::InvalidArgument(
        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
        DataTypeString(b.handle_type));
  }
  TF_RETURN_IF_ERROR(
      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
  return result;
}

}  // namespace tensorflow