Add more debugging capabilities to static / symbolic shape inference in Grappler.
* ValidateSymbolicShapeManager() checks whether there is any conflicts in inferred shapes / dims and symbolically merged shapes / dims (runs if VLOG >= 1) * VerboseShapeInferenceLogging() dumps node shape inference detail; as a graph usually has many nodes, users should add nodes of interest to node_names_for_logging set to limit logging, otherwise, the current implementation just skips this detailed logging (runs if VLOG >= 3). Note that there's no change in functionalities of shape inference. This change only adds more logging when requested. PiperOrigin-RevId: 351618668 Change-Id: I7e0e972910c4fcd7d2ddb7c4fc0220b20e09cab5
This commit is contained in:
parent
78fdd635ab
commit
ee6aba78b6
@ -2062,11 +2062,132 @@ class SymbolicShapeManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns merged shape with merged dimensions.
|
||||
ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
|
||||
const auto& actual_shape = shapes_.GetMergedValue(s);
|
||||
if (!InferenceContext::RankKnown(actual_shape)) {
|
||||
return ic->UnknownShape();
|
||||
} else {
|
||||
std::vector<DimensionHandle> dims;
|
||||
for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
|
||||
shape_inference::DimensionHandle dim =
|
||||
InferenceContext::DimKnownRank(actual_shape, j);
|
||||
int64 d = dims_.GetMergedValue(dim);
|
||||
// Symbolic shape manager may made some dims < -1, which causes errors
|
||||
// in creating Dimension.
|
||||
if (d < -1) {
|
||||
d = -1;
|
||||
}
|
||||
dims.push_back(ic->MakeDim(d));
|
||||
}
|
||||
return ic->MakeShape(dims);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DisjointSet<shape_inference::ShapeHandle> shapes_;
|
||||
DisjointSet<shape_inference::DimensionHandle> dims_;
|
||||
};
|
||||
|
||||
// Checks whether there is any conflict in merged shapes and dims in
|
||||
// SymbolicShapeManager.
|
||||
Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
|
||||
SymbolicShapeRefiner* refiner,
|
||||
SymbolicShapeManager* shape_manager) {
|
||||
if (!VLOG_IS_ON(1)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(1) << "Checking any conflics in shapes and dimensions ...";
|
||||
int64 num_incompatible_shapes = 0;
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
auto ctx = refiner->GetNodeContext(&node);
|
||||
if (!ctx) {
|
||||
continue;
|
||||
}
|
||||
auto* ic = ctx->inference_context.get();
|
||||
for (int i = 0; i < ic->num_inputs(); ++i) {
|
||||
const auto& shape = ic->input(i);
|
||||
const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
|
||||
if (!refiner->CompatibleShapes(shape, merged_shape)) {
|
||||
num_incompatible_shapes++;
|
||||
VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
|
||||
<< "for node " << node.name() << " input (" << i << ") "
|
||||
<< ic->DebugString(shape)
|
||||
<< " vs. merged: " << ic->DebugString(merged_shape);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < ic->num_outputs(); ++i) {
|
||||
const auto& shape = ic->output(i);
|
||||
const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
|
||||
if (!refiner->CompatibleShapes(shape, merged_shape)) {
|
||||
num_incompatible_shapes++;
|
||||
VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
|
||||
<< "for node " << node.name() << " output (" << i << ") "
|
||||
<< ic->DebugString(shape)
|
||||
<< " vs. merged: " << ic->DebugString(merged_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (num_incompatible_shapes > 0) {
|
||||
VLOG(1) << "**** WARNING: " << num_incompatible_shapes
|
||||
<< " incompatible shapes from SymbolicShapeManager.";
|
||||
} else {
|
||||
VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Log shape inference and its merged shapes.
|
||||
Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
|
||||
SymbolicShapeRefiner* refiner,
|
||||
SymbolicShapeManager* shape_manager) {
|
||||
// As logging all the nodes would generate too many lines, we by default
|
||||
// skip this detailed logging. Users may add nodes of interest to
|
||||
// node_names_for_logging to enable detailed logging.
|
||||
absl::flat_hash_set<std::string> node_names_for_logging = {};
|
||||
if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto should_log = [&node_names_for_logging](std::string node_name) {
|
||||
return node_names_for_logging.find(node_name) !=
|
||||
node_names_for_logging.end();
|
||||
};
|
||||
|
||||
for (const NodeDef& node : graph_def.node()) {
|
||||
if (!should_log(node.name())) {
|
||||
continue;
|
||||
}
|
||||
auto ctx = refiner->GetNodeContext(&node);
|
||||
if (!ctx) {
|
||||
continue;
|
||||
}
|
||||
auto* ic = ctx->inference_context.get();
|
||||
VLOG(3) << "Shape inference for node : " << node.name();
|
||||
VLOG(3) << ctx->DebugString(node);
|
||||
std::string merged_shapes = "Merged shapes from SymbolicShapManager:\n";
|
||||
for (int i = 0; i < ic->num_inputs(); ++i) {
|
||||
absl::StrAppend(
|
||||
&merged_shapes, " input[", i, "] -- ",
|
||||
ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
|
||||
"\n");
|
||||
}
|
||||
for (int i = 0; i < ic->num_outputs(); ++i) {
|
||||
absl::StrAppend(
|
||||
&merged_shapes, " output[", i, "] -- ",
|
||||
ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
|
||||
"\n");
|
||||
}
|
||||
VLOG(3) << merged_shapes;
|
||||
VLOG(3) << "--------------------------------";
|
||||
VLOG(3) << "";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
|
||||
SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
|
||||
const std::vector<ShapeAndType>& shapes_and_types,
|
||||
@ -2488,8 +2609,11 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
|
||||
shape_manager.get()));
|
||||
|
||||
for (const NodeDef& node : item_.graph.node()) {
|
||||
VLOG(3) << "Filling in graph properties for node: " << node.name();
|
||||
VLOG(4) << "Filling in graph properties for node: " << node.name();
|
||||
auto ctx = refiner->GetNodeContext(&node);
|
||||
if (!ctx) {
|
||||
continue;
|
||||
@ -2583,6 +2707,9 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
|
||||
output_properties_);
|
||||
|
||||
TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
|
||||
shape_manager.get()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user