[MLIR][NFC] Adopt FuncOp/Region argument API's.
- Use FuncOp::getArguments() and Region::getArguments() and friends where possible instead of going through the front() block. PiperOrigin-RevId: 325352975 Change-Id: Ib3dcfed692c0e04c554120a748f82e9efe009b89
This commit is contained in:
parent
8c0c1e1730
commit
3b47c2bdea
@ -147,9 +147,9 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||||||
// Now copy over the actual body of the reduction, leaving out the
|
// Now copy over the actual body of the reduction, leaving out the
|
||||||
// terminator.
|
// terminator.
|
||||||
BlockAndValueMapping mapping;
|
BlockAndValueMapping mapping;
|
||||||
mapping.map(reduce_op.body().front().getArgument(0), accumulator);
|
mapping.map(reduce_op.body().getArgument(0), accumulator);
|
||||||
mapping.map(reduce_op.body().front().getArgument(1), rhs);
|
mapping.map(reduce_op.body().getArgument(1), rhs);
|
||||||
mapping.map(reduce_op.body().front().getArgument(2), accumulator);
|
mapping.map(reduce_op.body().getArgument(2), accumulator);
|
||||||
for (auto& nested : reduce_op.body().front().without_terminator()) {
|
for (auto& nested : reduce_op.body().front().without_terminator()) {
|
||||||
auto clone = rewriter.clone(nested, mapping);
|
auto clone = rewriter.clone(nested, mapping);
|
||||||
for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
|
for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
|
||||||
|
@ -80,7 +80,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
// The basic block arguments correspond to values that are loop carried, while
|
// The basic block arguments correspond to values that are loop carried, while
|
||||||
// all those post are loop independent. Initialize extern_values with while_op
|
// all those post are loop independent. Initialize extern_values with while_op
|
||||||
// not loop carried operands.
|
// not loop carried operands.
|
||||||
auto num_loop_carried = while_op.cond().front().getNumArguments();
|
auto num_loop_carried = while_op.cond().getNumArguments();
|
||||||
auto not_carried_operands =
|
auto not_carried_operands =
|
||||||
while_op.getOperands().drop_front(num_loop_carried);
|
while_op.getOperands().drop_front(num_loop_carried);
|
||||||
extern_values.insert(not_carried_operands.begin(),
|
extern_values.insert(not_carried_operands.begin(),
|
||||||
@ -124,8 +124,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
// Collect new types.
|
// Collect new types.
|
||||||
SmallVector<Type, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
types.reserve(extra_operands.size() + while_op.getNumOperands());
|
types.reserve(extra_operands.size() + while_op.getNumOperands());
|
||||||
for (BlockArgument ba : while_op.cond().front().getArguments())
|
for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type);
|
||||||
types.push_back(ba.getType());
|
|
||||||
for (Value operand : extern_values) types.push_back(operand.getType());
|
for (Value operand : extern_values) types.push_back(operand.getType());
|
||||||
|
|
||||||
// Create outline function from region. Optional pass extra arguments through
|
// Create outline function from region. Optional pass extra arguments through
|
||||||
|
@ -2873,7 +2873,7 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
|
|||||||
mlir::OpBuilder builder(func.getBody());
|
mlir::OpBuilder builder(func.getBody());
|
||||||
llvm::SmallVector<mlir::Type, 4> new_input_types;
|
llvm::SmallVector<mlir::Type, 4> new_input_types;
|
||||||
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
|
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
|
||||||
auto arg = func.front().getArgument(i);
|
auto arg = func.getArgument(i);
|
||||||
auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
|
auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType<
|
||||||
mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
|
mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table);
|
||||||
if (global_tensor) {
|
if (global_tensor) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user