diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index a93eed3ce80..810332e93ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Traits.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir @@ -933,23 +934,19 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> { auto fetch_op = llvm::cast<FetchOp>(block.back()); auto island_op = llvm::cast<IslandOp>(block.front()); - auto yield_op = llvm::cast<YieldOp>(island_op.GetBody().back()); + Operation &yield_op = island_op.GetBody().back(); - // Mapping from island results to inner ops results. - llvm::SmallDenseMap<Value *, Value *, 8> island_rets_to_ops; - for (auto ops_and_ret_vals : - llvm::zip(island_op.outputs(), yield_op.fetches())) { - island_rets_to_ops.insert( - {std::get<0>(ops_and_ret_vals), std::get<1>(ops_and_ret_vals)}); - } - - // Map graph results to inner ops results. + // Map graph results to inner ops results of single island. llvm::SmallVector<Value *, 8> new_rets; - for (auto *fetch : fetch_op.fetches()) { - if (auto *op = island_rets_to_ops.lookup(fetch)) - new_rets.push_back(op); - else - new_rets.push_back(fetch); + for (Value *operand : fetch_op.fetches()) { + if (operand->getDefiningOp() != island_op) { + // Operand is not from island, simply propagate it out. + new_rets.push_back(operand); + } else { + // Lookup yield operand in island for inner op result. + auto result = llvm::cast<OpResult>(operand); + new_rets.push_back(yield_op.getOperand(result->getResultNumber())); + } } // Move inner ops from island to block containing graph.