[MLIR][NFC] Rename then_func() and friends to spell out function.

- This is in preparation for adding similar functions for CaseOp.

PiperOrigin-RevId: 332352891
Change-Id: Ifd6802857aef0392159c3221c1f42061fc83e1e4
Author: Rahul Joshi (committed by TensorFlower Gardener)
Date:   2020-09-17 17:40:01 -07:00
Parent: c2176c4121
Commit: c643fbcc96

19 changed files with 70 additions and 67 deletions
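
The change is a mechanical rename of the FuncOp-returning accessors on the functional control-flow ops (TF::IfOp and TF::WhileOp) and of their call sites; behavior is unchanged (NFC). A minimal sketch of a caller after the rename, assuming the usual MLIR/TF includes of that era (the helper name here is illustrative, not from this commit):

    #include "mlir/IR/Function.h"  // mlir::FuncOp (header location circa 2020)
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"  // mlir::TF::WhileOp

    // Illustrative helper (not from this commit): resolve both branch
    // functions of a functional While using the renamed accessors.
    static bool HasResolvableBranches(mlir::TF::WhileOp while_op) {
      mlir::FuncOp cond = while_op.cond_function();  // was: cond_func()
      mlir::FuncOp body = while_op.body_function();  // was: body_func()
      return cond && body;  // either may be null if the symbol is undefined
    }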

@@ -64,8 +64,8 @@ void RunOnWhile(TF::WhileOp while_op) {
     // Mark old function as private so that it can be DCE'd if not called.
     func.setVisibility(SymbolTable::Visibility::Private);
   };
 
-  create_region_with_call(while_op.cond_func(), new_op.cond());
-  create_region_with_call(while_op.body_func(), new_op.body());
+  create_region_with_call(while_op.cond_function(), new_op.cond());
+  create_region_with_call(while_op.body_function(), new_op.body());
   op->replaceAllUsesWith(new_op.getResults());
   op->erase();

@@ -749,7 +749,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) {
 // Changes the function type of `cond_func` and `body_func` for the given While
 // op.
 LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
-  for (FuncOp func : {op.cond_func(), op.body_func()}) {
+  for (FuncOp func : {op.cond_function(), op.body_function()}) {
     if (!func) continue;
     FunctionType func_type = func.getType();

@@ -83,8 +83,8 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
     if (!llvm::hasSingleElement(parent_op)) return failure();
 
     // Find the then and else branch functions.
-    FuncOp then_func = op.then_func();
-    FuncOp else_func = op.else_func();
+    FuncOp then_func = op.then_function();
+    FuncOp else_func = op.else_function();
 
     // If the If has no uses and its functions are side-effect free, then
     // remove.

@@ -351,7 +351,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
                               result);
     } else if (auto while_op = dyn_cast<WhileOp>(op)) {
       AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
-                                     while_op.body_func()));
+                                     while_op.body_function()));
     } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
       AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
                                          while_region.body()));
@@ -364,8 +364,9 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
       AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
     } else if (auto if_op = dyn_cast<IfOp>(op)) {
-      AnalyzeFunctionalCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()},
-                                  backtrack_analysis);
+      AnalyzeFunctionalCaseOrIfOp(
+          if_op, {if_op.then_function(), if_op.else_function()},
+          backtrack_analysis);
     } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
       AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
     } else if (auto call = dyn_cast<CallOpInterface>(op)) {

@@ -305,12 +305,12 @@ else_branch: A function that takes 'inputs' and returns a list of
   let extraClassDeclaration = [{
     // Get the then branch function.
-    FuncOp then_func() {
+    FuncOp then_function() {
       return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, then_branch());
     }
 
     // Get the else branch function.
-    FuncOp else_func() {
+    FuncOp else_function() {
       return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, else_branch());
     }
   }];
@@ -661,12 +661,12 @@ body: A function that takes a list of tensors and returns another
   let extraClassDeclaration = [{
     // Get the condition function.
-    FuncOp cond_func() {
+    FuncOp cond_function() {
       return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, cond());
     }
 
     // Get the body function.
-    FuncOp body_func() {
+    FuncOp body_function() {
       return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, body());
     }
   }];
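
These ODS extraClassDeclaration bodies are the single definition point for the renamed accessors: each resolves the op's flat symbol reference attribute (then_branch, else_branch, cond, body) against the nearest enclosing symbol table, typically the ModuleOp, and yields a null FuncOp when the symbol is undefined. That is why call sites throughout this diff guard the result (e.g. the `if (!func) continue;` above, and the WhileOp verifier's "refers to an undefined function" error below). A usage sketch under those assumptions (the surrounding pass boilerplate is not from this commit):

    // Sketch: callers must tolerate a null FuncOp, because the referenced
    // symbol may not resolve (see the WhileOp verifier below).
    if (mlir::FuncOp cond = while_op.cond_function()) {
      mlir::FunctionType type = cond.getType();  // inspect the resolved function
      (void)type;
    } else {
      while_op.emitOpError("cond refers to an undefined function : ")
          << while_op.cond();
    }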

@@ -2288,8 +2288,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
 }
 
 static LogicalResult Verify(WhileOp op) {
-  auto cond_fn = op.cond_func();
-  auto body_fn = op.body_func();
+  auto cond_fn = op.cond_function();
+  auto body_fn = op.body_function();
   if (!cond_fn) {
     return op.emitOpError("cond refers to an undefined function : ")
            << op.cond();

@@ -181,14 +181,14 @@ llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
     llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op) {
   for (auto& use : collection.getUses()) {
     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
-      auto body = while_op.body_func();
+      auto body = while_op.body_function();
       assert(body);
       auto type_from_body = GetElementTypeFromAccess(
           body.getArgument(use.getOperandNumber()), module, infer_from_op);
       if (type_from_body.hasValue()) return type_from_body;
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
-      auto then_branch = if_op.then_func();
-      auto else_branch = if_op.else_func();
+      auto then_branch = if_op.then_function();
+      auto else_branch = if_op.else_function();
       assert(then_branch && else_branch);
       auto type_from_then = GetElementTypeFromAccess(
           then_branch.getArgument(use.getOperandNumber() - 1), module,

@@ -157,14 +157,14 @@ static LogicalResult LowerIfOp(IfOp op) {
   // Set up the 'then' block.
   Block* then_block = builder.createBlock(merge_block);
-  Operation* call_op = CallFn(loc, get_operand, op.then_func(), &builder);
+  Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
 
   auto get_then_result = [&](int i) { return call_op->getResult(i); };
   JumpToBlock(loc, get_then_result, merge_block, &builder);
 
   // Set up the 'else' block.
   Block* else_block = builder.createBlock(merge_block);
-  call_op = CallFn(loc, get_operand, op.else_func(), &builder);
+  call_op = CallFn(loc, get_operand, op.else_function(), &builder);
 
   auto get_else_result = [&](int i) { return call_op->getResult(i); };
   JumpToBlock(loc, get_else_result, merge_block, &builder);
@@ -190,8 +190,8 @@ static LogicalResult LowerWhileOp(WhileOp op) {
   OpBuilder builder(op_inst);
 
-  auto cond_fn = op.cond_func();
-  auto body_fn = op.body_func();
+  auto cond_fn = op.cond_function();
+  auto body_fn = op.body_function();
 
   // Split the block containing the While op into two blocks. One containing
   // operations before the While op and other containing the rest. Create two

@@ -98,10 +98,10 @@ LogicalResult ConvertIfOp(IfOp if_op) {
       if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless());
   CopyDeviceAndUnderscoredAttributes(if_op, if_region);
 
-  CreateCall(if_op, if_op.then_func(),
+  CreateCall(if_op, if_op.then_function(),
              /*caller_region=*/if_region.then_branch(), if_op.input(),
              /*use_region_args=*/false);
-  CreateCall(if_op, if_op.else_func(),
+  CreateCall(if_op, if_op.else_function(),
              /*caller_region=*/if_region.else_branch(), if_op.input(),
              /*use_region_args=*/false);
   if_op.replaceAllUsesWith(if_region.getResults());
@@ -116,14 +116,14 @@ LogicalResult ConvertWhileOp(WhileOp while_op) {
   CopyDeviceAndUnderscoredAttributes(while_op, while_region);
 
   YieldOp cond_yield =
-      CreateCall(while_op, while_op.cond_func(),
+      CreateCall(while_op, while_op.cond_function(),
                  /*caller_region=*/while_region.cond(), while_op.input(),
                  /*use_region_args=*/true);
   Value i1_cond =
       ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
   cond_yield.setOperand(0, i1_cond);
 
-  CreateCall(while_op, while_op.body_func(),
+  CreateCall(while_op, while_op.body_function(),
              /*caller_region=*/while_region.body(), while_op.input(),
              /*use_region_args=*/true);
   while_op.replaceAllUsesWith(while_region.getResults());

@@ -109,13 +109,14 @@ class ResourceAnalyzer {
       return;
     }
     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
-      for (auto callee : {if_op.then_func(), if_op.else_func()}) {
+      for (auto callee : {if_op.then_function(), if_op.else_function()}) {
        PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
      }
      return;
    }
    if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
-      for (auto callee : {while_op.cond_func(), while_op.body_func()}) {
+      for (auto callee :
+           {while_op.cond_function(), while_op.body_function()}) {
        PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
      }
      return;

@@ -283,12 +283,13 @@ void ResourceDeviceInference::runOnOperation() {
         if (auto while_op = dyn_cast<WhileOp>(op)) {
           if (failed(propagate_operands_to_callee_arguments(
                   while_op, while_op.getOperands(),
-                  {while_op.body_func(), while_op.cond_func()}, func_res)))
+                  {while_op.body_function(), while_op.cond_function()},
+                  func_res)))
             return WalkResult::interrupt();
         } else if (auto if_op = dyn_cast<IfOp>(op)) {
           if (failed(propagate_operands_to_callee_arguments(
-                  if_op, if_op.input(), {if_op.then_func(), if_op.else_func()},
-                  func_res)))
+                  if_op, if_op.input(),
+                  {if_op.then_function(), if_op.else_function()}, func_res)))
             return WalkResult::interrupt();
         } else if (auto call = dyn_cast<CallOpInterface>(op)) {
           auto func = dyn_cast<FuncOp>(call.resolveCallable());

@@ -1204,8 +1204,8 @@ LogicalResult HoistForControlFlow(
         lifted_partitioned_call_callees) {
   for (Operation& op : llvm::make_early_inc_range(*block)) {
     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
-      auto body = while_op.body_func();
-      auto cond = while_op.cond_func();
+      auto body = while_op.body_function();
+      auto cond = while_op.cond_function();
       // Recursively handle the nested control flow.
       HoistForControlFlow(&body.front(), module,
                           lifted_partitioned_call_callees);
@@ -1213,8 +1213,8 @@ LogicalResult HoistForControlFlow(
                           lifted_partitioned_call_callees);
       if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
-      auto then_branch = if_op.then_func();
-      auto else_branch = if_op.else_func();
+      auto then_branch = if_op.then_function();
+      auto else_branch = if_op.else_function();
       // Recursively handle the nested control flow.
       HoistForControlFlow(&then_branch.front(), module,
                           lifted_partitioned_call_callees);

@@ -188,8 +188,8 @@ void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef<FuncOp> branches) {
 
 // Eliminated unused results from a functional while.
 void EliminateUnusedResultsForWhile(TF::WhileOp op) {
-  FuncOp cond = op.cond_func();
-  FuncOp body = op.body_func();
+  FuncOp cond = op.cond_function();
+  FuncOp body = op.body_function();
 
   llvm::BitVector can_eliminate(op.getNumResults());
   for (OpResult result : llvm::reverse(op.getResults())) {
@@ -304,14 +304,14 @@ LogicalResult CanonicalizeFunctionalIfCase(Operation *op,
 // Canonicalizes a functional while. Forwards common argument to results and
 // drop resource results if posible.
 LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) {
-  for (FuncOp func : {op.cond_func(), op.body_func()}) {
+  for (FuncOp func : {op.cond_function(), op.body_function()}) {
     if (failed(CleanupAndCanonicalize(func))) return failure();
   }
 
   // For while, just use the body function to forward operand to result.
   bool has_resource_result = false;
-  if (failed(ForwardCommonArgToOutput(op, {op.body_func()}, op.getOperands(),
-                                      has_resource_result)))
+  if (failed(ForwardCommonArgToOutput(op, {op.body_function()},
+                                      op.getOperands(), has_resource_result)))
     return failure();
 
   // If no resource type results were found, no further cleanup needed.
   if (!has_resource_result) return success();
@@ -412,7 +412,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
       result = CanonicalizeFunctionalIfCase(
-          op, {if_op.then_func(), if_op.else_func()}, if_op.input());
+          op, {if_op.then_function(), if_op.else_function()}, if_op.input());
     } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
       SmallVector<FuncOp, 4> branches;
       for (Attribute branch : case_op.branches()) {
@@ -422,7 +422,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
       }
       result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
     } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
-      if (while_op.cond_func().walk(check_while_cond).wasInterrupted())
+      if (while_op.cond_function().walk(check_while_cond).wasInterrupted())
         return WalkResult::interrupt();
       result = CanonicalizeFunctionalWhile(while_op);
     } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, tf_device::ClusterOp>(

@@ -241,8 +241,8 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
 // function result types.
 bool InferShapeForIf(IfOp op) {
   bool changed = false;
-  auto then_results = op.then_func().getType().getResults();
-  auto else_results = op.else_func().getType().getResults();
+  auto then_results = op.then_function().getType().getResults();
+  auto else_results = op.else_function().getType().getResults();
   for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
     // If then and else types do not match, skip refinement for that result.
     if (std::get<1>(it) != std::get<2>(it)) continue;
@@ -924,7 +924,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
     return PropagateShapeToFunctions(
         module, drop_begin(if_op.getOperandTypes(), 1),
-        {if_op.then_func(), if_op.else_func()}, max_iteration);
+        {if_op.then_function(), if_op.else_function()}, max_iteration);
   } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
     SmallVector<FuncOp, 4> branches;
     for (Attribute branch : case_op.branches()) {
@@ -937,7 +937,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
     return PropagateShapeToFunctions(
         module, while_op.getOperandTypes(),
-        {while_op.cond_func(), while_op.body_func()}, max_iteration);
+        {while_op.cond_function(), while_op.body_function()}, max_iteration);
   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
     if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
       PropagateConstantToCallee(call_op, func, module);

@@ -163,7 +163,7 @@ LogicalResult HandleWhileOp(
     const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
     llvm::StringMap<PartitionedCallStackOpsInfo>*
         decomposed_partitioned_call_callees) {
-  auto body = while_op.body_func();
+  auto body = while_op.body_function();
   llvm::SmallDenseMap<Value, Value> body_map;
   auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
     auto it = data_var_to_size_var.find(while_op.getOperand(index));
@@ -187,7 +187,7 @@ LogicalResult HandleWhileOp(
     return failure();
   }
   // Cond should not change stacks in the arguments, so use an empty map.
-  auto cond = while_op.cond_func();
+  auto cond = while_op.cond_function();
   ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
   llvm::SmallDenseMap<Value, Value> empty_map;
   if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
@@ -231,8 +231,8 @@ LogicalResult HandleIfOp(
     const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
     llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
-  auto then_func = if_op.then_func();
-  auto else_func = if_op.else_func();
+  auto then_func = if_op.then_function();
+  auto else_func = if_op.else_function();
 
   llvm::SmallDenseMap<Value, Value> then_map;
   llvm::SmallDenseMap<Value, Value> else_map;

@@ -443,12 +443,12 @@ llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
         insert(grad.handle(), grad.source().str());
       } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
         for (const auto& entry : AccessedGradients(
-                 {while_op.body_func(), while_op.cond_func()}, module))
+                 {while_op.body_function(), while_op.cond_function()}, module))
           for (const string& source : entry.getSecond())
             insert(while_op.getOperand(entry.getFirst()), source);
       } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
-        for (const auto& entry :
-             AccessedGradients({if_op.then_func(), if_op.else_func()}, module))
+        for (const auto& entry : AccessedGradients(
+                 {if_op.then_function(), if_op.else_function()}, module))
          for (const string& source : entry.getSecond())
            insert(if_op.getOperand(entry.getFirst() + 1), source);
      } else if (auto call = llvm::dyn_cast<CallOpInterface>(&op)) {
@@ -509,8 +509,8 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
     llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
         decomposed_partitioned_call_callees) {
-  auto body = while_op.body_func();
-  auto cond = while_op.cond_func();
+  auto body = while_op.body_function();
+  auto cond = while_op.cond_function();
   auto grads = AccessedGradients({body, cond}, module);
   auto ta_arg_buffer_type = [&](int64_t index) -> Type {
     auto it = stats->find(while_op.getOperand(index));
@@ -592,8 +592,8 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
     llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
         decomposed_partitioned_call_callees) {
-  auto then_branch = if_op.then_func();
-  auto else_branch = if_op.else_func();
+  auto then_branch = if_op.then_function();
+  auto else_branch = if_op.else_function();
   auto grads = AccessedGradients({then_branch, else_branch}, module);
   auto ta_arg_buffer_type = [&](int64_t index) -> Type {
     auto it = stats->find(if_op.getOperand(index + 1));

@@ -155,7 +155,7 @@ LogicalResult HandleWhileOp(
     llvm::StringMap<PartitionedCallDecompositionInfo>*
         decomposed_partitioned_call_callees) {
   // Rewrite body.
-  auto body = while_op.body_func();
+  auto body = while_op.body_function();
   llvm::SmallDenseMap<Value, SizeInfo> body_map;
   auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
     auto it = buffer_to_size->find(while_op.getOperand(index));
@@ -176,7 +176,7 @@ LogicalResult HandleWhileOp(
   auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map);
 
   // Rewrite cond.
-  auto cond = while_op.cond_func();
+  auto cond = while_op.cond_function();
   llvm::SmallDenseMap<Value, SizeInfo> cond_map;
   ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
                           find_arg_tensor_list_type, arg_buffer_size_is_fixed);
@@ -701,9 +701,9 @@ LogicalResult DecomposeTensorListOpsInternal(
         return failure();
       }
     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
-      if (failed(HandleCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()},
-                                  module, buffer_to_size,
-                                  decomposed_partitioned_call_callees))) {
+      if (failed(HandleCaseOrIfOp(
+              if_op, {if_op.then_function(), if_op.else_function()}, module,
+              buffer_to_size, decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {

@@ -452,8 +452,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
       !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
     return;
 
-  FuncOp body = while_op.body_func();
-  FuncOp cond = while_op.cond_func();
+  FuncOp body = while_op.body_function();
+  FuncOp cond = while_op.cond_function();
 
   // Analyze the formattable inputs.
   auto execute_arg_to_outer_args =

@@ -119,8 +119,8 @@ void LowerIf(TF::IfOp op, ModuleOp module) {
   // Import the regions for both the true and false cases. These regions
   // must be updated to tuple the return results together and use the xla hlo
   // return op.
-  ImportXlaRegion(op.then_func(), &if_op.true_branch(), loc);
-  ImportXlaRegion(op.else_func(), &if_op.false_branch(), loc);
+  ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
+  ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);
 
   // De-tuple the results of the xla hlo if result.
   Detuple(if_op.getResult(), op.getResults(), &builder);
@@ -172,8 +172,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
   // Import the regions for both the cond and body. These regions must be
   // updated to tuple the return results together and use the xla hlo return op.
-  ImportXlaRegion(op.body_func(), &while_op.body(), loc);
-  ImportXlaRegion(op.cond_func(), &while_op.cond(), loc,
+  ImportXlaRegion(op.body_function(), &while_op.body(), loc);
+  ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
                   /*tuple_return=*/false);
 
   // De-tuple the results of the xla hlo while.
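
The stated motivation is follow-up work on CaseOp, which holds its branch symbols in an array attribute (see case_op.branches() in the hunks above), so the corresponding accessor will presumably take a branch index. A hypothetical sketch in the style of the extraClassDeclaration bodies above (not part of this commit):

    // Hypothetical CaseOp accessor, mirroring then_function()/else_function();
    // not added by this commit, which only prepares the naming scheme.
    FuncOp branch_function(int index) {
      auto branch = branches()[index].cast<FlatSymbolRefAttr>();
      return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this,
                                                          branch.getValue());
    }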