Add more debugging capabilities to static / symbolic shape inference in Grappler.

* ValidateSymbolicShapeManager() checks whether there are any conflicts between inferred shapes / dims and symbolically merged shapes / dims (runs if VLOG >= 1)
* VerboseShapeInferenceLogging() dumps node shape inference detail; as a graph usually has many nodes, users should add nodes of interest to node_names_for_logging set to limit logging, otherwise, the current implementation just skips this detailed logging (runs if VLOG >= 3).

Note that this does not change the functionality of shape inference. This change only adds more logging when requested.

PiperOrigin-RevId: 351618668
Change-Id: I7e0e972910c4fcd7d2ddb7c4fc0220b20e09cab5
This commit is contained in:
Doe Hyun Yoon 2021-01-13 10:45:36 -08:00 committed by TensorFlower Gardener
parent 78fdd635ab
commit ee6aba78b6

View File

@ -2062,11 +2062,132 @@ class SymbolicShapeManager {
}
}
// Looks up the merged value of shape `s` and rebuilds it inside `ic`, with
// each dimension replaced by its merged value. If the merged shape has no
// known rank, ic->UnknownShape() is returned instead.
ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
  const auto& merged = shapes_.GetMergedValue(s);
  if (!InferenceContext::RankKnown(merged)) {
    return ic->UnknownShape();
  }

  const int rank = InferenceContext::Rank(merged);
  std::vector<DimensionHandle> merged_dims;
  for (int axis = 0; axis < rank; ++axis) {
    DimensionHandle inferred_dim = InferenceContext::DimKnownRank(merged, axis);
    const int64 merged_value = dims_.GetMergedValue(inferred_dim);
    // The symbolic shape manager may have made some dims < -1, which causes
    // errors when creating a Dimension; clamp those values to -1.
    merged_dims.push_back(ic->MakeDim(merged_value < -1 ? -1 : merged_value));
  }
  return ic->MakeShape(merged_dims);
}
private:
DisjointSet<shape_inference::ShapeHandle> shapes_;
DisjointSet<shape_inference::DimensionHandle> dims_;
};
// Checks whether there is any conflict between the shapes inferred for each
// node and the merged shapes / dims held by SymbolicShapeManager.
//
// This is a debugging aid only: it does nothing unless VLOG level 1 is
// enabled, reports every incompatibility (and a final summary) through
// VLOG(1), and always returns Status::OK().
//
// Args:
//   graph_def: graph whose nodes are checked.
//   refiner: provides the per-node inference context and the shape
//     compatibility check.
//   shape_manager: provides the merged shape for each inferred shape.
Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
                                    SymbolicShapeRefiner* refiner,
                                    SymbolicShapeManager* shape_manager) {
  if (!VLOG_IS_ON(1)) {
    return Status::OK();
  }

  VLOG(1) << "Checking any conflicts in shapes and dimensions ...";
  int64 num_incompatible_shapes = 0;
  for (const NodeDef& node : graph_def.node()) {
    auto ctx = refiner->GetNodeContext(&node);
    if (!ctx) {
      // No inference context for this node; there is nothing to compare.
      continue;
    }
    auto* ic = ctx->inference_context.get();

    // Compares the inferred shape of one input or output port against its
    // merged shape, counting and logging the pair when they are incompatible.
    // Inputs and outputs share this logic; only the accessor and the label in
    // the log line differ.
    auto check_port = [&](bool is_input, int port) {
      const auto shape = is_input ? ic->input(port) : ic->output(port);
      const auto merged_shape = shape_manager->GetMergedShape(ic, shape);
      if (refiner->CompatibleShapes(shape, merged_shape)) {
        return;
      }
      ++num_incompatible_shapes;
      VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
              << "for node " << node.name()
              << (is_input ? " input (" : " output (") << port << ") "
              << ic->DebugString(shape)
              << " vs. merged: " << ic->DebugString(merged_shape);
    };

    for (int i = 0; i < ic->num_inputs(); ++i) {
      check_port(/*is_input=*/true, i);
    }
    for (int i = 0; i < ic->num_outputs(); ++i) {
      check_port(/*is_input=*/false, i);
    }
  }

  if (num_incompatible_shapes > 0) {
    VLOG(1) << "**** WARNING: " << num_incompatible_shapes
            << " incompatible shapes from SymbolicShapeManager.";
  } else {
    VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
  }

  return Status::OK();
}
// Logs shape inference details and the corresponding merged shapes for
// selected nodes.
//
// This is a debugging aid only: it does nothing unless VLOG level 3 is enabled
// AND node_names_for_logging (below) is non-empty. All output goes through
// VLOG(3), and the function always returns Status::OK().
//
// Args:
//   graph_def: graph whose nodes are candidates for logging.
//   refiner: provides the per-node inference context.
//   shape_manager: provides the merged shape for each input / output shape.
Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
                                    SymbolicShapeRefiner* refiner,
                                    SymbolicShapeManager* shape_manager) {
  // As logging all the nodes would generate too many lines, we by default
  // skip this detailed logging. Users may add nodes of interest to
  // node_names_for_logging to enable detailed logging.
  const absl::flat_hash_set<std::string> node_names_for_logging = {};
  if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
    return Status::OK();
  }

  for (const NodeDef& node : graph_def.node()) {
    // Only nodes explicitly listed above are logged.
    if (!node_names_for_logging.contains(node.name())) {
      continue;
    }
    auto ctx = refiner->GetNodeContext(&node);
    if (!ctx) {
      // No inference context for this node; nothing to report.
      continue;
    }
    auto* ic = ctx->inference_context.get();

    VLOG(3) << "Shape inference for node : " << node.name();
    VLOG(3) << ctx->DebugString(node);

    // Collect the merged shapes of all inputs and outputs into one string so
    // they are emitted as a single log entry.
    std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
    for (int i = 0; i < ic->num_inputs(); ++i) {
      absl::StrAppend(
          &merged_shapes, " input[", i, "] -- ",
          ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
          "\n");
    }
    for (int i = 0; i < ic->num_outputs(); ++i) {
      absl::StrAppend(
          &merged_shapes, " output[", i, "] -- ",
          ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
          "\n");
    }
    VLOG(3) << merged_shapes;
    VLOG(3) << "--------------------------------";
    VLOG(3) << "";
  }

  return Status::OK();
}
Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<ShapeAndType>& shapes_and_types,
@ -2488,8 +2609,11 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
}
}
TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
shape_manager.get()));
for (const NodeDef& node : item_.graph.node()) {
VLOG(3) << "Filling in graph properties for node: " << node.name();
VLOG(4) << "Filling in graph properties for node: " << node.name();
auto ctx = refiner->GetNodeContext(&node);
if (!ctx) {
continue;
@ -2583,6 +2707,9 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
output_properties_);
TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
shape_manager.get()));
return Status::OK();
}