diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 3889e93878a..b4469b41cd8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -267,7 +267,7 @@ LogicalResult IfOp::verify() { auto elseAttr = getAttrOfType("else_branch"); if (!elseAttr) return emitOpError("requires else_branch attribute"); - auto module = getOperation()->getFunction().getModule(); + auto module = getParentOfType(); auto thenFn = module.getNamedFunction(thenAttr.getValue()); if (!thenFn) return emitOpError("then_branch refers to an undefined function : ") @@ -716,7 +716,7 @@ LogicalResult WhileOp::verify() { auto condAttr = getAttrOfType("cond"); if (!condAttr) return emitOpError("requires cond attribute"); - auto module = getOperation()->getFunction().getModule(); + auto module = getParentOfType(); auto condFn = module.getNamedFunction(condAttr.getValue()); auto condFuncType = condFn.getType(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 18dea5647de..3e34fb53f4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -151,7 +151,7 @@ static bool LowerIfOp(IfOp op) { Value* cond_i1 = LowerCondition(loc, op.getCondition(), &builder); if (!cond_i1) return true; - auto module = op_inst->getFunction().getModule(); + auto module = op_inst->getParentOfType(); auto then_fn = module.getNamedFunction(op.getThen()); auto else_fn = module.getNamedFunction(op.getElse()); @@ -208,7 +208,7 @@ static bool LowerWhileOp(WhileOp op) { OpBuilder builder(op_inst); - auto module = op_inst->getFunction().getModule(); + auto module = op_inst->getParentOfType(); auto cond_fn = module.getNamedFunction(op.getCond()); auto body_fn = module.getNamedFunction(op.getBody()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 806a20003b3..0efbdb7b55b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -194,8 +194,10 @@ std::string Exporter::UniqueName(mlir::Operation* op) { StatusOr> Exporter::GetArgumentNode( mlir::BlockArgument* arg, unsigned index) { auto node_def = absl::make_unique(); - node_def->set_name( - UniqueName(arg->getOwner()->getFunction().getName().str())); + node_def->set_name(UniqueName(arg->getContainingRegion() + ->getParentOfType() + .getName() + .str())); node_def->set_op(FunctionLibraryDefinition::kArgOp); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( @@ -213,7 +215,8 @@ StatusOr> Exporter::GetReturnNode( mlir::Operation* inst, unsigned index) { auto node_def = absl::make_unique(); auto* inst_op = inst->getOperand(index); - node_def->set_name(UniqueName(inst->getFunction().getName().str())); + node_def->set_name( + UniqueName(inst->getParentOfType().getName().str())); node_def->set_op(FunctionLibraryDefinition::kRetOp); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( @@ -316,7 +319,8 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { // is an input node. We recover the original input node and skip adding the // argument node. The new input node will be handled as normal in the // following steps. - if (arg->getFunction().getName() == "main") { + if (arg->getContainingRegion()->getParentOfType().getName() == + "main") { if (!arg->hasOneUse()) { return errors::FailedPrecondition( "Arg in 'main' should only have one user."); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc index 5d084a1406e..c0553de183e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc @@ -795,7 +795,8 @@ Status Importer::ConvertFunctionArgAndRets( "max", builder_->getF32FloatAttr(input_spec.max_value))); state.attributes.push_back(builder_->getNamedAttr( "type", builder_->getTypeAttr(final_type))); - bb->getFunction().setAttr("tf.quantize", builder_->getUnitAttr()); + inst->getParentOfType().setAttr( + "tf.quantize", builder_->getUnitAttr()); } }