From 73ed7102b76ba5ac43ac6b327e717ad3a7734364 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 3 Jul 2019 13:21:24 -0700
Subject: [PATCH] Replace the implementation of Function and Module with
 FuncOp and ModuleOp.

This is an important step toward making the top level of the IR
extensible. FuncOp and ModuleOp provide all of the necessary
functionality while reusing the existing operation infrastructure. As an
interim step, many usages of Function and Module, including the names,
will remain the same. In the future these will be relaxed to allow
different types of top-level operations to coexist.

PiperOrigin-RevId: 256427100
---
 .../mlir/lite/transforms/lower_static_tensor_list.cc      | 2 +-
 .../tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt | 2 +-
 .../tensorflow/transforms/tf_graph_optimization_pass.cc   | 8 +++++---
 .../compiler/mlir/tensorflow/translate/export_graphdef.cc | 8 +++++---
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index c4d8464d3d8..45849d90261 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -383,7 +383,7 @@ void LowerStaticTensorListPass::runOnModule() {
   std::vector<Function> funcs_in_module;
   for (auto func : getModule().getFunctions()) {
     // Always place the main function to be the first in the list.
-    if (func.getName().is("main")) {
+    if (func.getName() == "main") {
       funcs_in_module.insert(funcs_in_module.begin(), func);
     } else {
       funcs_in_module.push_back(func);
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
index b1746dc2319..ac84234e4ac 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt
@@ -203,7 +203,7 @@ versions {
   producer: 27
 }

-# CHECK: func @main() loc(unknown) {
+# CHECK: func @main() {
 # CHECK-NEXT: %0:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control) loc("while/NextIteration")
 # CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<0> : tensor<i32>} : () -> (tensor<i32>, !_tf.control) loc("Const")
 # CHECK-NEXT: %2:2 = "_tf.Enter"(%1#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor<i32>) -> (tensor<*xi32>, !_tf.control) loc("while/Enter")
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
index 1f75362117f..c835ea64158 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc
@@ -144,9 +144,11 @@ void GraphOptPass::runOnModule() {
   auto module_out = std::move(module_or_status).ValueOrDie();

   // We cannot replace the module in a ModulePass. So we simply copy the
-  // Function list from module_out to module_in.
-  module_in.clear();
-  module_in.splice(module_in.getFunctions().end(), *module_out);
+  // operation list from module_out to module_in.
+  auto& module_in_ops = module_in.getBody()->getOperations();
+  module_in_ops.clear();
+  module_in_ops.splice(module_in_ops.end(),
+                       module_out->getBody()->getOperations());
 }

 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index a2823826fe4..65a4cde1dcd 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -310,7 +310,7 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) {
   // is an input node. We recover the original input node and skip adding the
   // argument node. The new input node will be handled as normal in the
   // following steps.
-  if (arg->getFunction().getName().is("main")) {
+  if (arg->getFunction().getName() == "main") {
     if (!arg->hasOneUse()) {
       return errors::FailedPrecondition(
           "Arg in 'main' should only have one user.");
@@ -507,8 +507,10 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs,
   // Ignore the gradient attribute on the function as it gets converted to
   // GradientDef.
   absl::flat_hash_set<absl::string_view> attrs_to_ignore = {grad_string};
-  TF_RETURN_IF_ERROR(ConvertAttributes(function.getAttrs(), attrs_to_ignore,
-                                       func_def.mutable_attr()));
+  llvm::SmallVector<mlir::NamedAttribute, 11> funcAttrs(
+      function.getDialectAttrs());
+  TF_RETURN_IF_ERROR(
+      ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr()));
   (*flib->add_function()) = func_def;
   return Status::OK();
 }
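
Note (not part of the patch): a minimal sketch of the FuncOp/ModuleOp
idioms this change migrates to, written against the MLIR header names of
this era ("mlir/IR/Module.h", "mlir/IR/Function.h"). The helper names
below are hypothetical; the ModuleOp/FuncOp calls are the ones the hunks
above rely on.

    // Sketch only: ReplaceModuleBody and CollectFunctions are hypothetical
    // helpers illustrating the API pattern, not code from this patch.
    #include <vector>

    #include "llvm/Support/Casting.h"
    #include "mlir/IR/Function.h"  // mlir::FuncOp
    #include "mlir/IR/Module.h"    // mlir::ModuleOp

    // Moves every top-level operation from `src` into `dst`, mirroring the
    // tf_graph_optimization_pass.cc hunk: a ModulePass cannot swap out its
    // module, so the body's operation list is spliced across instead.
    static void ReplaceModuleBody(mlir::ModuleOp dst, mlir::ModuleOp src) {
      auto& dst_ops = dst.getBody()->getOperations();
      dst_ops.clear();
      dst_ops.splice(dst_ops.end(), src.getBody()->getOperations());
    }

    // Collects a module's functions with @main first, mirroring the
    // lower_static_tensor_list.cc hunk: FuncOp::getName() returns a
    // StringRef, so plain == comparison replaces Identifier::is().
    static std::vector<mlir::FuncOp> CollectFunctions(mlir::ModuleOp module) {
      std::vector<mlir::FuncOp> funcs;
      for (mlir::Operation& op : *module.getBody()) {
        // Top-level operations are no longer guaranteed to be functions.
        auto func = llvm::dyn_cast<mlir::FuncOp>(&op);
        if (!func) continue;
        if (func.getName() == "main")
          funcs.insert(funcs.begin(), func);
        else
          funcs.push_back(func);
      }
      return funcs;
    }

Splicing transfers the operations rather than cloning them, which is why
the pass above can keep module_in's identity while adopting module_out's
contents. The dyn_cast in the loop reflects the point of the commit: once
the module body is a generic operation list, code can no longer assume
every top-level operation is a function.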