[MLIR][NFC] Rename then_func() and friends to spell out function.
- This is in preparation of adding similar functions for CaseOp. PiperOrigin-RevId: 332352891 Change-Id: Ifd6802857aef0392159c3221c1f42061fc83e1e4
This commit is contained in:
parent
c2176c4121
commit
c643fbcc96
@ -64,8 +64,8 @@ void RunOnWhile(TF::WhileOp while_op) {
|
||||
// Mark old function as private so that it can be DCE'd if not called.
|
||||
func.setVisibility(SymbolTable::Visibility::Private);
|
||||
};
|
||||
create_region_with_call(while_op.cond_func(), new_op.cond());
|
||||
create_region_with_call(while_op.body_func(), new_op.body());
|
||||
create_region_with_call(while_op.cond_function(), new_op.cond());
|
||||
create_region_with_call(while_op.body_function(), new_op.body());
|
||||
|
||||
op->replaceAllUsesWith(new_op.getResults());
|
||||
op->erase();
|
||||
|
@ -749,7 +749,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) {
|
||||
// Changes the function type of `cond_func` and `body_func` for the given While
|
||||
// op.
|
||||
LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
for (FuncOp func : {op.cond_func(), op.body_func()}) {
|
||||
for (FuncOp func : {op.cond_function(), op.body_function()}) {
|
||||
if (!func) continue;
|
||||
|
||||
FunctionType func_type = func.getType();
|
||||
|
@ -83,8 +83,8 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
if (!llvm::hasSingleElement(parent_op)) return failure();
|
||||
|
||||
// Find the then and else branch functions.
|
||||
FuncOp then_func = op.then_func();
|
||||
FuncOp else_func = op.else_func();
|
||||
FuncOp then_func = op.then_function();
|
||||
FuncOp else_func = op.else_function();
|
||||
|
||||
// If the If has no uses and its functions are side-effect free, then
|
||||
// remove.
|
||||
|
@ -351,7 +351,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
result);
|
||||
} else if (auto while_op = dyn_cast<WhileOp>(op)) {
|
||||
AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
|
||||
while_op.body_func()));
|
||||
while_op.body_function()));
|
||||
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
|
||||
AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
|
||||
while_region.body()));
|
||||
@ -364,8 +364,9 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
|
||||
|
||||
AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
|
||||
} else if (auto if_op = dyn_cast<IfOp>(op)) {
|
||||
AnalyzeFunctionalCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()},
|
||||
backtrack_analysis);
|
||||
AnalyzeFunctionalCaseOrIfOp(
|
||||
if_op, {if_op.then_function(), if_op.else_function()},
|
||||
backtrack_analysis);
|
||||
} else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
|
||||
AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
|
||||
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
|
||||
|
@ -305,12 +305,12 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the then branch function.
|
||||
FuncOp then_func() {
|
||||
FuncOp then_function() {
|
||||
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, then_branch());
|
||||
}
|
||||
|
||||
// Get the else branch function.
|
||||
FuncOp else_func() {
|
||||
FuncOp else_function() {
|
||||
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, else_branch());
|
||||
}
|
||||
}];
|
||||
@ -661,12 +661,12 @@ body: A function that takes a list of tensors and returns another
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the condition function.
|
||||
FuncOp cond_func() {
|
||||
FuncOp cond_function() {
|
||||
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, cond());
|
||||
}
|
||||
|
||||
// Get the body function.
|
||||
FuncOp body_func() {
|
||||
FuncOp body_function() {
|
||||
return SymbolTable::lookupNearestSymbolFrom<FuncOp>(*this, body());
|
||||
}
|
||||
}];
|
||||
|
@ -2288,8 +2288,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
|
||||
}
|
||||
|
||||
static LogicalResult Verify(WhileOp op) {
|
||||
auto cond_fn = op.cond_func();
|
||||
auto body_fn = op.body_func();
|
||||
auto cond_fn = op.cond_function();
|
||||
auto body_fn = op.body_function();
|
||||
if (!cond_fn) {
|
||||
return op.emitOpError("cond refers to an undefined function : ")
|
||||
<< op.cond();
|
||||
|
@ -181,14 +181,14 @@ llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
|
||||
llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op) {
|
||||
for (auto& use : collection.getUses()) {
|
||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
|
||||
auto body = while_op.body_func();
|
||||
auto body = while_op.body_function();
|
||||
assert(body);
|
||||
auto type_from_body = GetElementTypeFromAccess(
|
||||
body.getArgument(use.getOperandNumber()), module, infer_from_op);
|
||||
if (type_from_body.hasValue()) return type_from_body;
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
|
||||
auto then_branch = if_op.then_func();
|
||||
auto else_branch = if_op.else_func();
|
||||
auto then_branch = if_op.then_function();
|
||||
auto else_branch = if_op.else_function();
|
||||
assert(then_branch && else_branch);
|
||||
auto type_from_then = GetElementTypeFromAccess(
|
||||
then_branch.getArgument(use.getOperandNumber() - 1), module,
|
||||
|
@ -157,14 +157,14 @@ static LogicalResult LowerIfOp(IfOp op) {
|
||||
|
||||
// Set up the 'then' block.
|
||||
Block* then_block = builder.createBlock(merge_block);
|
||||
Operation* call_op = CallFn(loc, get_operand, op.then_func(), &builder);
|
||||
Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
|
||||
|
||||
auto get_then_result = [&](int i) { return call_op->getResult(i); };
|
||||
JumpToBlock(loc, get_then_result, merge_block, &builder);
|
||||
|
||||
// Set up the 'else' block.
|
||||
Block* else_block = builder.createBlock(merge_block);
|
||||
call_op = CallFn(loc, get_operand, op.else_func(), &builder);
|
||||
call_op = CallFn(loc, get_operand, op.else_function(), &builder);
|
||||
|
||||
auto get_else_result = [&](int i) { return call_op->getResult(i); };
|
||||
JumpToBlock(loc, get_else_result, merge_block, &builder);
|
||||
@ -190,8 +190,8 @@ static LogicalResult LowerWhileOp(WhileOp op) {
|
||||
|
||||
OpBuilder builder(op_inst);
|
||||
|
||||
auto cond_fn = op.cond_func();
|
||||
auto body_fn = op.body_func();
|
||||
auto cond_fn = op.cond_function();
|
||||
auto body_fn = op.body_function();
|
||||
|
||||
// Split the block containing the While op into two blocks. One containing
|
||||
// operations before the While op and other containing the rest. Create two
|
||||
|
@ -98,10 +98,10 @@ LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless());
|
||||
CopyDeviceAndUnderscoredAttributes(if_op, if_region);
|
||||
|
||||
CreateCall(if_op, if_op.then_func(),
|
||||
CreateCall(if_op, if_op.then_function(),
|
||||
/*caller_region=*/if_region.then_branch(), if_op.input(),
|
||||
/*use_region_args=*/false);
|
||||
CreateCall(if_op, if_op.else_func(),
|
||||
CreateCall(if_op, if_op.else_function(),
|
||||
/*caller_region=*/if_region.else_branch(), if_op.input(),
|
||||
/*use_region_args=*/false);
|
||||
if_op.replaceAllUsesWith(if_region.getResults());
|
||||
@ -116,14 +116,14 @@ LogicalResult ConvertWhileOp(WhileOp while_op) {
|
||||
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
|
||||
|
||||
YieldOp cond_yield =
|
||||
CreateCall(while_op, while_op.cond_func(),
|
||||
CreateCall(while_op, while_op.cond_function(),
|
||||
/*caller_region=*/while_region.cond(), while_op.input(),
|
||||
/*use_region_args=*/true);
|
||||
Value i1_cond =
|
||||
ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
|
||||
cond_yield.setOperand(0, i1_cond);
|
||||
|
||||
CreateCall(while_op, while_op.body_func(),
|
||||
CreateCall(while_op, while_op.body_function(),
|
||||
/*caller_region=*/while_region.body(), while_op.input(),
|
||||
/*use_region_args=*/true);
|
||||
while_op.replaceAllUsesWith(while_region.getResults());
|
||||
|
@ -109,13 +109,14 @@ class ResourceAnalyzer {
|
||||
return;
|
||||
}
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
for (auto callee : {if_op.then_func(), if_op.else_func()}) {
|
||||
for (auto callee : {if_op.then_function(), if_op.else_function()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
for (auto callee : {while_op.cond_func(), while_op.body_func()}) {
|
||||
for (auto callee :
|
||||
{while_op.cond_function(), while_op.body_function()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
|
||||
}
|
||||
return;
|
||||
|
@ -283,12 +283,13 @@ void ResourceDeviceInference::runOnOperation() {
|
||||
if (auto while_op = dyn_cast<WhileOp>(op)) {
|
||||
if (failed(propagate_operands_to_callee_arguments(
|
||||
while_op, while_op.getOperands(),
|
||||
{while_op.body_func(), while_op.cond_func()}, func_res)))
|
||||
{while_op.body_function(), while_op.cond_function()},
|
||||
func_res)))
|
||||
return WalkResult::interrupt();
|
||||
} else if (auto if_op = dyn_cast<IfOp>(op)) {
|
||||
if (failed(propagate_operands_to_callee_arguments(
|
||||
if_op, if_op.input(), {if_op.then_func(), if_op.else_func()},
|
||||
func_res)))
|
||||
if_op, if_op.input(),
|
||||
{if_op.then_function(), if_op.else_function()}, func_res)))
|
||||
return WalkResult::interrupt();
|
||||
} else if (auto call = dyn_cast<CallOpInterface>(op)) {
|
||||
auto func = dyn_cast<FuncOp>(call.resolveCallable());
|
||||
|
@ -1204,8 +1204,8 @@ LogicalResult HoistForControlFlow(
|
||||
lifted_partitioned_call_callees) {
|
||||
for (Operation& op : llvm::make_early_inc_range(*block)) {
|
||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
auto body = while_op.body_func();
|
||||
auto cond = while_op.cond_func();
|
||||
auto body = while_op.body_function();
|
||||
auto cond = while_op.cond_function();
|
||||
// Recursively handle the nested control flow.
|
||||
HoistForControlFlow(&body.front(), module,
|
||||
lifted_partitioned_call_callees);
|
||||
@ -1213,8 +1213,8 @@ LogicalResult HoistForControlFlow(
|
||||
lifted_partitioned_call_callees);
|
||||
if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
auto then_branch = if_op.then_func();
|
||||
auto else_branch = if_op.else_func();
|
||||
auto then_branch = if_op.then_function();
|
||||
auto else_branch = if_op.else_function();
|
||||
// Recursively handle the nested control flow.
|
||||
HoistForControlFlow(&then_branch.front(), module,
|
||||
lifted_partitioned_call_callees);
|
||||
|
@ -188,8 +188,8 @@ void EliminateUnusedResultsForIfCase(Operation *op, ArrayRef<FuncOp> branches) {
|
||||
|
||||
// Eliminated unused results from a functional while.
|
||||
void EliminateUnusedResultsForWhile(TF::WhileOp op) {
|
||||
FuncOp cond = op.cond_func();
|
||||
FuncOp body = op.body_func();
|
||||
FuncOp cond = op.cond_function();
|
||||
FuncOp body = op.body_function();
|
||||
|
||||
llvm::BitVector can_eliminate(op.getNumResults());
|
||||
for (OpResult result : llvm::reverse(op.getResults())) {
|
||||
@ -304,14 +304,14 @@ LogicalResult CanonicalizeFunctionalIfCase(Operation *op,
|
||||
// Canonicalizes a functional while. Forwards common argument to results and
|
||||
// drop resource results if posible.
|
||||
LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) {
|
||||
for (FuncOp func : {op.cond_func(), op.body_func()}) {
|
||||
for (FuncOp func : {op.cond_function(), op.body_function()}) {
|
||||
if (failed(CleanupAndCanonicalize(func))) return failure();
|
||||
}
|
||||
|
||||
// For while, just use the body function to forward operand to result.
|
||||
bool has_resource_result = false;
|
||||
if (failed(ForwardCommonArgToOutput(op, {op.body_func()}, op.getOperands(),
|
||||
has_resource_result)))
|
||||
if (failed(ForwardCommonArgToOutput(op, {op.body_function()},
|
||||
op.getOperands(), has_resource_result)))
|
||||
return failure();
|
||||
// If no resource type results were found, no further cleanup needed.
|
||||
if (!has_resource_result) return success();
|
||||
@ -412,7 +412,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
|
||||
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
result = CanonicalizeFunctionalIfCase(
|
||||
op, {if_op.then_func(), if_op.else_func()}, if_op.input());
|
||||
op, {if_op.then_function(), if_op.else_function()}, if_op.input());
|
||||
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
|
||||
SmallVector<FuncOp, 4> branches;
|
||||
for (Attribute branch : case_op.branches()) {
|
||||
@ -422,7 +422,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) {
|
||||
}
|
||||
result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input());
|
||||
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
if (while_op.cond_func().walk(check_while_cond).wasInterrupted())
|
||||
if (while_op.cond_function().walk(check_while_cond).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
result = CanonicalizeFunctionalWhile(while_op);
|
||||
} else if (isa<TF::IfRegionOp, TF::CaseRegionOp, tf_device::ClusterOp>(
|
||||
|
@ -241,8 +241,8 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
|
||||
// function result types.
|
||||
bool InferShapeForIf(IfOp op) {
|
||||
bool changed = false;
|
||||
auto then_results = op.then_func().getType().getResults();
|
||||
auto else_results = op.else_func().getType().getResults();
|
||||
auto then_results = op.then_function().getType().getResults();
|
||||
auto else_results = op.else_function().getType().getResults();
|
||||
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
|
||||
// If then and else types do not match, skip refinement for that result.
|
||||
if (std::get<1>(it) != std::get<2>(it)) continue;
|
||||
@ -924,7 +924,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
return PropagateShapeToFunctions(
|
||||
module, drop_begin(if_op.getOperandTypes(), 1),
|
||||
{if_op.then_func(), if_op.else_func()}, max_iteration);
|
||||
{if_op.then_function(), if_op.else_function()}, max_iteration);
|
||||
} else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
|
||||
SmallVector<FuncOp, 4> branches;
|
||||
for (Attribute branch : case_op.branches()) {
|
||||
@ -937,7 +937,7 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
return PropagateShapeToFunctions(
|
||||
module, while_op.getOperandTypes(),
|
||||
{while_op.cond_func(), while_op.body_func()}, max_iteration);
|
||||
{while_op.cond_function(), while_op.body_function()}, max_iteration);
|
||||
} else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
|
||||
if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) {
|
||||
PropagateConstantToCallee(call_op, func, module);
|
||||
|
@ -163,7 +163,7 @@ LogicalResult HandleWhileOp(
|
||||
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
|
||||
llvm::StringMap<PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto body = while_op.body_func();
|
||||
auto body = while_op.body_function();
|
||||
llvm::SmallDenseMap<Value, Value> body_map;
|
||||
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
|
||||
auto it = data_var_to_size_var.find(while_op.getOperand(index));
|
||||
@ -187,7 +187,7 @@ LogicalResult HandleWhileOp(
|
||||
return failure();
|
||||
}
|
||||
// Cond should not change stacks in the arguments, so use an empty map.
|
||||
auto cond = while_op.cond_func();
|
||||
auto cond = while_op.cond_function();
|
||||
ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
|
||||
llvm::SmallDenseMap<Value, Value> empty_map;
|
||||
if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
|
||||
@ -231,8 +231,8 @@ LogicalResult HandleIfOp(
|
||||
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
|
||||
llvm::StringMap<PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto then_func = if_op.then_func();
|
||||
auto else_func = if_op.else_func();
|
||||
auto then_func = if_op.then_function();
|
||||
auto else_func = if_op.else_function();
|
||||
llvm::SmallDenseMap<Value, Value> then_map;
|
||||
llvm::SmallDenseMap<Value, Value> else_map;
|
||||
|
||||
|
@ -443,12 +443,12 @@ llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
|
||||
insert(grad.handle(), grad.source().str());
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
for (const auto& entry : AccessedGradients(
|
||||
{while_op.body_func(), while_op.cond_func()}, module))
|
||||
{while_op.body_function(), while_op.cond_function()}, module))
|
||||
for (const string& source : entry.getSecond())
|
||||
insert(while_op.getOperand(entry.getFirst()), source);
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
for (const auto& entry :
|
||||
AccessedGradients({if_op.then_func(), if_op.else_func()}, module))
|
||||
for (const auto& entry : AccessedGradients(
|
||||
{if_op.then_function(), if_op.else_function()}, module))
|
||||
for (const string& source : entry.getSecond())
|
||||
insert(if_op.getOperand(entry.getFirst() + 1), source);
|
||||
} else if (auto call = llvm::dyn_cast<CallOpInterface>(&op)) {
|
||||
@ -509,8 +509,8 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto body = while_op.body_func();
|
||||
auto cond = while_op.cond_func();
|
||||
auto body = while_op.body_function();
|
||||
auto cond = while_op.cond_function();
|
||||
auto grads = AccessedGradients({body, cond}, module);
|
||||
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
|
||||
auto it = stats->find(while_op.getOperand(index));
|
||||
@ -592,8 +592,8 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto then_branch = if_op.then_func();
|
||||
auto else_branch = if_op.else_func();
|
||||
auto then_branch = if_op.then_function();
|
||||
auto else_branch = if_op.else_function();
|
||||
auto grads = AccessedGradients({then_branch, else_branch}, module);
|
||||
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
|
||||
auto it = stats->find(if_op.getOperand(index + 1));
|
||||
|
@ -155,7 +155,7 @@ LogicalResult HandleWhileOp(
|
||||
llvm::StringMap<PartitionedCallDecompositionInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
// Rewrite body.
|
||||
auto body = while_op.body_func();
|
||||
auto body = while_op.body_function();
|
||||
llvm::SmallDenseMap<Value, SizeInfo> body_map;
|
||||
auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
|
||||
auto it = buffer_to_size->find(while_op.getOperand(index));
|
||||
@ -176,7 +176,7 @@ LogicalResult HandleWhileOp(
|
||||
auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map);
|
||||
|
||||
// Rewrite cond.
|
||||
auto cond = while_op.cond_func();
|
||||
auto cond = while_op.cond_function();
|
||||
llvm::SmallDenseMap<Value, SizeInfo> cond_map;
|
||||
ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
|
||||
find_arg_tensor_list_type, arg_buffer_size_is_fixed);
|
||||
@ -701,9 +701,9 @@ LogicalResult DecomposeTensorListOpsInternal(
|
||||
return failure();
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
if (failed(HandleCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()},
|
||||
module, buffer_to_size,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
if (failed(HandleCaseOrIfOp(
|
||||
if_op, {if_op.then_function(), if_op.else_function()}, module,
|
||||
buffer_to_size, decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
|
||||
|
@ -452,8 +452,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
!llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
|
||||
return;
|
||||
|
||||
FuncOp body = while_op.body_func();
|
||||
FuncOp cond = while_op.cond_func();
|
||||
FuncOp body = while_op.body_function();
|
||||
FuncOp cond = while_op.cond_function();
|
||||
|
||||
// Analyze the formattable inputs.
|
||||
auto execute_arg_to_outer_args =
|
||||
|
@ -119,8 +119,8 @@ void LowerIf(TF::IfOp op, ModuleOp module) {
|
||||
// Import the regions for both the true and false cases. These regions
|
||||
// must be updated to tuple the return results together and use the xla hlo
|
||||
// return op.
|
||||
ImportXlaRegion(op.then_func(), &if_op.true_branch(), loc);
|
||||
ImportXlaRegion(op.else_func(), &if_op.false_branch(), loc);
|
||||
ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
|
||||
ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);
|
||||
|
||||
// De-tuple the results of the xla hlo if result.
|
||||
Detuple(if_op.getResult(), op.getResults(), &builder);
|
||||
@ -172,8 +172,8 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
|
||||
|
||||
// Import the regions for both the cond and body. These regions must be
|
||||
// updated to tuple the return results together and use the xla hlo return op.
|
||||
ImportXlaRegion(op.body_func(), &while_op.body(), loc);
|
||||
ImportXlaRegion(op.cond_func(), &while_op.cond(), loc,
|
||||
ImportXlaRegion(op.body_function(), &while_op.body(), loc);
|
||||
ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
|
||||
/*tuple_return=*/false);
|
||||
|
||||
// De-tuple the results of the xla hlo while.
|
||||
|
Loading…
x
Reference in New Issue
Block a user