From 80375c5b11eededcfee33b51ebf638993be6ca6c Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 8 Feb 2021 11:50:26 -0800 Subject: [PATCH] Minor tweak to shape inference to avoid an extra always-true comparison (NFC) The loop is comparing every elements to the first one in the range, we can drop_front() from the range since the first element does not need to compare to iself. PiperOrigin-RevId: 356317270 Change-Id: Ieb54a4308e86bf9b510d81fa7bc46a32a361589c --- .../mlir/tensorflow/transforms/shape_inference.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index ea8d35e5e55..f4a22dfa865 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -1146,7 +1146,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( DCOMMENT("Propating shape to" << func.getName()); ArrayRef callers = GetCallers(func); if (!llvm::hasSingleElement(callers) && - !llvm::all_of(callers, [&](Operation* caller) { + !llvm::all_of(callers.drop_front(), [&](Operation* caller) { /// TODO(aminim): this is overly conservative as some operations /// (like TPUPartitionedCallOp) may have extra operands that aren't /// propagated to the callee. @@ -1438,8 +1438,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { RecordValue(ValuePort(std::get<0>(result)), attr); } else { DCOMMENT("\t\tValue result unmapped, consider value type:" << value); - RefineResultType(op, - std::get<0>(result), value.getType()); + RefineResultType(op, std::get<0>(result), value.getType()); } } @@ -1654,8 +1653,8 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) { return success(); } int64_t producer = producer_or.ValueOrDie(); - // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if it is no - // longer needed. + // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if + // it is no longer needed. ShapeInference context(producer, module.getContext(), /*propagate_caller_callee_constants=*/false); if (auto main = module.lookupSymbol("main")) @@ -1667,7 +1666,7 @@ LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) { while (!context.EmptyQueue()) { FuncOp func = context.front(); if (failed(InferShapeForFunction(context, func, max_iterations))) - return failure(); + return failure(); context.pop_front(); if ((--max_iteration) == 0) {