NFC: Remove the various "::getFunction" methods.
These methods assume that a function is a valid builtin top-level operation, and removing these methods allows for decoupling FuncOp and IR/. Utility "getParentOfType" methods have been added to Operation/OpState to allow for querying the first parent operation of a given type. PiperOrigin-RevId: 257018913
This commit is contained in:
parent
38f264b268
commit
355f7167d7
@ -267,7 +267,7 @@ LogicalResult IfOp::verify() {
|
||||
auto elseAttr = getAttrOfType<FunctionAttr>("else_branch");
|
||||
if (!elseAttr) return emitOpError("requires else_branch attribute");
|
||||
|
||||
auto module = getOperation()->getFunction().getModule();
|
||||
auto module = getParentOfType<Module>();
|
||||
auto thenFn = module.getNamedFunction(thenAttr.getValue());
|
||||
if (!thenFn)
|
||||
return emitOpError("then_branch refers to an undefined function : ")
|
||||
@ -716,7 +716,7 @@ LogicalResult WhileOp::verify() {
|
||||
auto condAttr = getAttrOfType<FunctionAttr>("cond");
|
||||
if (!condAttr) return emitOpError("requires cond attribute");
|
||||
|
||||
auto module = getOperation()->getFunction().getModule();
|
||||
auto module = getParentOfType<Module>();
|
||||
auto condFn = module.getNamedFunction(condAttr.getValue());
|
||||
auto condFuncType = condFn.getType();
|
||||
|
||||
|
@ -151,7 +151,7 @@ static bool LowerIfOp(IfOp op) {
|
||||
Value* cond_i1 = LowerCondition(loc, op.getCondition(), &builder);
|
||||
if (!cond_i1) return true;
|
||||
|
||||
auto module = op_inst->getFunction().getModule();
|
||||
auto module = op_inst->getParentOfType<Module>();
|
||||
auto then_fn = module.getNamedFunction(op.getThen());
|
||||
auto else_fn = module.getNamedFunction(op.getElse());
|
||||
|
||||
@ -208,7 +208,7 @@ static bool LowerWhileOp(WhileOp op) {
|
||||
|
||||
OpBuilder builder(op_inst);
|
||||
|
||||
auto module = op_inst->getFunction().getModule();
|
||||
auto module = op_inst->getParentOfType<Module>();
|
||||
auto cond_fn = module.getNamedFunction(op.getCond());
|
||||
auto body_fn = module.getNamedFunction(op.getBody());
|
||||
|
||||
|
@ -194,8 +194,10 @@ std::string Exporter::UniqueName(mlir::Operation* op) {
|
||||
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
|
||||
mlir::BlockArgument* arg, unsigned index) {
|
||||
auto node_def = absl::make_unique<NodeDef>();
|
||||
node_def->set_name(
|
||||
UniqueName(arg->getOwner()->getFunction().getName().str()));
|
||||
node_def->set_name(UniqueName(arg->getContainingRegion()
|
||||
->getParentOfType<mlir::FuncOp>()
|
||||
.getName()
|
||||
.str()));
|
||||
node_def->set_op(FunctionLibraryDefinition::kArgOp);
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertToDataType(
|
||||
@ -213,7 +215,8 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
|
||||
mlir::Operation* inst, unsigned index) {
|
||||
auto node_def = absl::make_unique<NodeDef>();
|
||||
auto* inst_op = inst->getOperand(index);
|
||||
node_def->set_name(UniqueName(inst->getFunction().getName().str()));
|
||||
node_def->set_name(
|
||||
UniqueName(inst->getParentOfType<mlir::FuncOp>().getName().str()));
|
||||
node_def->set_op(FunctionLibraryDefinition::kRetOp);
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertToDataType(
|
||||
@ -316,7 +319,8 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) {
|
||||
// is an input node. We recover the original input node and skip adding the
|
||||
// argument node. The new input node will be handled as normal in the
|
||||
// following steps.
|
||||
if (arg->getFunction().getName() == "main") {
|
||||
if (arg->getContainingRegion()->getParentOfType<mlir::FuncOp>().getName() ==
|
||||
"main") {
|
||||
if (!arg->hasOneUse()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Arg in 'main' should only have one user.");
|
||||
|
@ -795,7 +795,8 @@ Status Importer::ConvertFunctionArgAndRets(
|
||||
"max", builder_->getF32FloatAttr(input_spec.max_value)));
|
||||
state.attributes.push_back(builder_->getNamedAttr(
|
||||
"type", builder_->getTypeAttr(final_type)));
|
||||
bb->getFunction().setAttr("tf.quantize", builder_->getUnitAttr());
|
||||
inst->getParentOfType<mlir::FuncOp>().setAttr(
|
||||
"tf.quantize", builder_->getUnitAttr());
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user