Minor tweak to shape inference to avoid an extra always-true comparison (NFC)
The loop is comparing every elements to the first one in the range, we can drop_front() from the range since the first element does not need to compare to iself. PiperOrigin-RevId: 356317270 Change-Id: Ieb54a4308e86bf9b510d81fa7bc46a32a361589c
This commit is contained in:
parent
05fb0863d3
commit
80375c5b11
@ -1146,7 +1146,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||
DCOMMENT("Propating shape to" << func.getName());
|
||||
ArrayRef<Operation*> callers = GetCallers(func);
|
||||
if (!llvm::hasSingleElement(callers) &&
|
||||
!llvm::all_of(callers, [&](Operation* caller) {
|
||||
!llvm::all_of(callers.drop_front(), [&](Operation* caller) {
|
||||
/// TODO(aminim): this is overly conservative as some operations
|
||||
/// (like TPUPartitionedCallOp) may have extra operands that aren't
|
||||
/// propagated to the callee.
|
||||
@ -1438,8 +1438,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
|
||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
||||
} else {
|
||||
DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
|
||||
RefineResultType(op,
|
||||
std::get<0>(result), value.getType());
|
||||
RefineResultType(op, std::get<0>(result), value.getType());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1654,8 +1653,8 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
|
||||
return success();
|
||||
}
|
||||
int64_t producer = producer_or.ValueOrDie();
|
||||
// TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if it is no
|
||||
// longer needed.
|
||||
// TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
|
||||
// it is no longer needed.
|
||||
ShapeInference context(producer, module.getContext(),
|
||||
/*propagate_caller_callee_constants=*/false);
|
||||
if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
|
||||
@ -1667,7 +1666,7 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
|
||||
while (!context.EmptyQueue()) {
|
||||
FuncOp func = context.front();
|
||||
if (failed(InferShapeForFunction(context, func, max_iterations)))
|
||||
return failure();
|
||||
return failure();
|
||||
context.pop_front();
|
||||
|
||||
if ((--max_iteration) == 0) {
|
||||
|
Loading…
Reference in New Issue
Block a user