diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index c4d8464d3d8..45849d90261 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -383,7 +383,7 @@ void LowerStaticTensorListPass::runOnModule() { std::vector funcs_in_module; for (auto func : getModule().getFunctions()) { // Always place the main function to be the first in the list. - if (func.getName().is("main")) { + if (func.getName() == "main") { funcs_in_module.insert(funcs_in_module.begin(), func); } else { funcs_in_module.push_back(func); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index b1746dc2319..ac84234e4ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -203,7 +203,7 @@ versions { producer: 27 } -# CHECK: func @main() loc(unknown) { +# CHECK: func @main() { # CHECK-NEXT: %0:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control) loc("while/NextIteration") # CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<0> : tensor} : () -> (tensor, !_tf.control) loc("Const") # CHECK-NEXT: %2:2 = "_tf.Enter"(%1#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor) -> (tensor<*xi32>, !_tf.control) loc("while/Enter") diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 1f75362117f..c835ea64158 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -144,9 +144,11 @@ void GraphOptPass::runOnModule() { auto module_out = std::move(module_or_status).ValueOrDie(); // We cannot replace the module in a ModulePass. So we simply copy the - // Function list from module_out to module_in. - module_in.clear(); - module_in.splice(module_in.getFunctions().end(), *module_out); + // operation list from module_out to module_in. + auto& module_in_ops = module_in.getBody()->getOperations(); + module_in_ops.clear(); + module_in_ops.splice(module_in_ops.end(), + module_out->getBody()->getOperations()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index a2823826fe4..65a4cde1dcd 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -310,7 +310,7 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { // is an input node. We recover the original input node and skip adding the // argument node. The new input node will be handled as normal in the // following steps. - if (arg->getFunction().getName().is("main")) { + if (arg->getFunction().getName() == "main") { if (!arg->hasOneUse()) { return errors::FailedPrecondition( "Arg in 'main' should only have one user."); @@ -507,8 +507,10 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs, // Ignore the gradient attribute on the function as it gets converted to // GradientDef. absl::flat_hash_set attrs_to_ignore = {grad_string}; - TF_RETURN_IF_ERROR(ConvertAttributes(function.getAttrs(), attrs_to_ignore, - func_def.mutable_attr())); + llvm::SmallVector funcAttrs( + function.getDialectAttrs()); + TF_RETURN_IF_ERROR( + ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr())); (*flib->add_function()) = func_def; return Status::OK(); }