Minor tweak to shape inference to avoid an extra always-true comparison (NFC)

The loop is comparing every elements to the first one in the range, we can drop_front()
from the range since the first element does not need to compare to iself.

PiperOrigin-RevId: 356317270
Change-Id: Ieb54a4308e86bf9b510d81fa7bc46a32a361589c
This commit is contained in:
Mehdi Amini 2021-02-08 11:50:26 -08:00 committed by TensorFlower Gardener
parent 05fb0863d3
commit 80375c5b11

View File

@ -1146,7 +1146,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
DCOMMENT("Propating shape to" << func.getName());
ArrayRef<Operation*> callers = GetCallers(func);
if (!llvm::hasSingleElement(callers) &&
!llvm::all_of(callers, [&](Operation* caller) {
!llvm::all_of(callers.drop_front(), [&](Operation* caller) {
/// TODO(aminim): this is overly conservative as some operations
/// (like TPUPartitionedCallOp) may have extra operands that aren't
/// propagated to the callee.
@ -1438,8 +1438,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) {
RecordValue(ValuePort(std::get<0>(result)), attr);
} else {
DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
RefineResultType(op,
std::get<0>(result), value.getType());
RefineResultType(op, std::get<0>(result), value.getType());
}
}
@ -1654,8 +1653,8 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
return success();
}
int64_t producer = producer_or.ValueOrDie();
// TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if it is no
// longer needed.
// TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
// it is no longer needed.
ShapeInference context(producer, module.getContext(),
/*propagate_caller_callee_constants=*/false);
if (auto main = module.lookupSymbol<mlir::FuncOp>("main"))
@ -1667,7 +1666,7 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) {
while (!context.EmptyQueue()) {
FuncOp func = context.front();
if (failed(InferShapeForFunction(context, func, max_iterations)))
return failure();
return failure();
context.pop_front();
if ((--max_iteration) == 0) {