diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD
index 6a47be332d0..66283bded71 100644
--- a/tensorflow/compiler/mlir/python/BUILD
+++ b/tensorflow/compiler/mlir/python/BUILD
@@ -37,6 +37,8 @@ cc_library(
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc
index 8bec288cda5..066726593a7 100644
--- a/tensorflow/compiler/mlir/python/mlir.cc
+++ b/tensorflow/compiler/mlir/python/mlir.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include <string>
 
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/InitAllPasses.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
@@ -28,9 +29,40 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op.h"
 
 namespace tensorflow {
 
+namespace {
+
+// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
+// empty, returning the module as a string; on failure, sets `status` and returns "// error".
+std::string RunPassPipelineOnModule(mlir::ModuleOp module,
+                                    const std::string &pass_pipeline,
+                                    TF_Status *status) {
+  if (!pass_pipeline.empty()) {
+    mlir::PassManager pm(module.getContext());
+    std::string error;
+    llvm::raw_string_ostream error_stream(error);
+    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
+      TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
+      return "// error";
+    }
+
+    mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
+    if (failed(pm.run(module))) {
+      Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
+      return "// error";
+    }
+  }
+  return MlirModuleToString(module);
+}
+
+}  // anonymous namespace
+
 std::string ImportGraphDef(const std::string &proto,
                            const std::string &pass_pipeline,
                            TF_Status *status) {
@@ -49,24 +81,43 @@ std::string ImportGraphDef(const std::string &proto,
     return "// error";
   }
 
-  // Run the pass_pipeline on the module if not empty.
-  if (!pass_pipeline.empty()) {
-    mlir::PassManager pm(&context);
-    std::string error;
-    llvm::raw_string_ostream error_stream(error);
-    if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
-      TF_SetStatus(status, TF_INVALID_ARGUMENT,
-                   ("Invalid pass_pipeline: " + error_stream.str()).c_str());
-      return "// error";
-    }
+  return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
+}
 
-    mlir::StatusScopedDiagnosticHandler statusHandler(&context);
-    if (failed(pm.run(*module.ValueOrDie()))) {
-      Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
-      return "// error";
-    }
+std::string ImportFunction(const std::string &functiondef_proto,
+                           const std::string &functiondef_library_proto,
+                           const std::string &pass_pipeline,
+                           TF_Status *status) {
+  FunctionDef functiondef;
+  auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
+  if (!s.ok()) {
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
   }
-  return MlirModuleToString(*module.ConsumeValueOrDie());
+
+  FunctionDefLibrary fdef_lib;
+  s = tensorflow::LoadProtoFromBuffer(functiondef_library_proto, &fdef_lib);
+  if (!s.ok()) {
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
+  }
+
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
+  s = flib_def.AddFunctionDef(functiondef);
+  if (!s.ok()) {
+    Set_TF_Status_from_Status(status, s);
+    return "// error";
+  }
+
+  const std::string &function_name = functiondef.signature().name();
+  mlir::MLIRContext context;
+  auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
+  if (!module.ok()) {
+    Set_TF_Status_from_Status(status, module.status());
+    return "// error";
+  }
+
+  return RunPassPipelineOnModule(module->get(), pass_pipeline, status);
 }
 
 std::string ExperimentalConvertSavedModelToMlir(
diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h
index e68ac28124b..6133068a5e8 100644
--- a/tensorflow/compiler/mlir/python/mlir.h
+++ b/tensorflow/compiler/mlir/python/mlir.h
@@ -25,13 +25,23 @@ limitations under the License.
 namespace tensorflow {
 
 // Simple wrapper to support tf.mlir.experimental.convert_graph_def.
-// Load a .pbptx, convert to MLIR, and (optionally) optimize the module before
-// returning it as a string.
+// Load a GraphDef (binary or textual proto format), convert to MLIR, and
+// (optionally) optimize the module before returning it as a string.
 // This is an early experimental API, ideally we should return a wrapper object
 // around a Python binding to the MLIR module.
 std::string ImportGraphDef(const std::string &proto,
                            const std::string &pass_pipeline, TF_Status *status);
 
+// Simple wrapper to support tf.mlir.experimental.convert_function.
+// Load FunctionDef and FunctionDefLibrary (binary or textual proto format),
+// convert to MLIR, and (optionally) optimize the module before returning it as
+// a string.
+// This is an early experimental API, ideally we should return a wrapper object
+// around a Python binding to the MLIR module.
+std::string ImportFunction(const std::string &functiondef_proto,
+                           const std::string &functiondef_library_proto,
+                           const std::string &pass_pipeline, TF_Status *status);
+
 // Load a SavedModel and return a textual MLIR string corresponding to it.
 //
 // Args:
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index b78f3112bdb..153c537589c 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3654,6 +3654,8 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
   tensorflow::GraphDebugInfo dummy_debug_info;
   tensorflow::GraphImportConfig specs;
   specs.graph_as_function = true;
+  for (const auto* control_ret_node : fbody->control_ret_nodes)
+    specs.control_outputs.push_back(control_ret_node->name());
   return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
                                    flib_def, specs, name);
 }
diff --git a/tensorflow/python/compiler/mlir/BUILD b/tensorflow/python/compiler/mlir/BUILD
index fe59213837b..7e193795e60 100644
--- a/tensorflow/python/compiler/mlir/BUILD
+++ b/tensorflow/python/compiler/mlir/BUILD
@@ -11,7 +11,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:pywrap_mlir",
-        "//tensorflow/python:util",
+        "//tensorflow/python:tf_export",
     ],
 )
 
@@ -22,6 +22,10 @@ py_test(
     deps = [
         ":mlir",
         "//tensorflow/python:client_testlib",
-        "//tensorflow/python:platform",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:logging_ops",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python/eager:def_function",
     ],
 )
diff --git a/tensorflow/python/compiler/mlir/mlir.py b/tensorflow/python/compiler/mlir/mlir.py
index fd9918d19f8..3b72abc2850 100644
--- a/tensorflow/python/compiler/mlir/mlir.py
+++ b/tensorflow/python/compiler/mlir/mlir.py
@@ -26,6 +26,9 @@ from tensorflow.python.util.tf_export import tf_export
 def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
   """Import a GraphDef and convert it to a textual MLIR module.
 
+  This API is only intended for inspecting the internals of TensorFlow; the
+  returned string is, for now, intended only for debugging purposes.
+
   Args:
     graph_def: An object of type graph_pb2.GraphDef or a textual proto
       representation of a valid GraphDef.
@@ -35,7 +38,51 @@ def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
 
   Returns:
     A textual representation of the MLIR module corresponding to the graphdef.
-    Raises a RuntimeError on error.
+
+  Raises:
+    InvalidArgumentError: if graph_def is invalid or cannot be converted to
+      MLIR.
 
   """
   return pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
+
+
+@tf_export('mlir.experimental.convert_function')
+def convert_function(concrete_function, pass_pipeline='tf-standard-pipeline'):
+  """Import a ConcreteFunction and convert it to a textual MLIR module.
+
+  This API is only intended for inspecting the internals of TensorFlow; the
+  returned string is, for now, intended only for debugging purposes.
+
+  A [tf.function](https://www.tensorflow.org/api_docs/python/tf/function) can be
+  imported and converted from TensorFlow to TensorFlow MLIR with this API by
+  extracting its ConcreteFunction (eagerly-executing wrapper around a
+  [tf.Graph](https://www.tensorflow.org/api_docs/python/tf/Graph)).
+
+  For example:
+  >>> @tf.function
+  ... def add(a, b):
+  ...   return a + b
+
+  >>> concrete_function = add.get_concrete_function(
+  ...     tf.TensorSpec(None, tf.dtypes.float32),
+  ...     tf.TensorSpec(None, tf.dtypes.float32))
+  >>> tf.mlir.experimental.convert_function(concrete_function)
+  '...module attributes {...} {...}'
+
+  Args:
+    concrete_function: An object of type ConcreteFunction.
+    pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
+      module, see MLIR documentation for the
+      [textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
+
+  Returns:
+    A textual representation of the MLIR module corresponding to the
+    ConcreteFunction.
+
+  Raises:
+    InvalidArgumentError: if concrete_function is invalid or cannot be converted
+      to MLIR.
+
+  """
+  return pywrap_mlir.import_function(concrete_function, pass_pipeline)
diff --git a/tensorflow/python/compiler/mlir/mlir_test.py b/tensorflow/python/compiler/mlir/mlir_test.py
index 2a2362d9f6b..9cb0063dc64 100644
--- a/tensorflow/python/compiler/mlir/mlir_test.py
+++ b/tensorflow/python/compiler/mlir/mlir_test.py
@@ -19,23 +19,68 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.compiler.mlir import mlir
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import logging_ops
 from tensorflow.python.platform import test
 
 
-class MLIRImportTest(test.TestCase):
+class MLIRGraphDefImportTest(test.TestCase):
 
-  def test_import_graph_def(self):
+  def testImport(self):
     """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
     mlir_module = mlir.convert_graph_def('')
     # An empty graph should contain at least an empty main function.
     self.assertIn('func @main', mlir_module)
 
-  def test_invalid_pbtxt(self):
+  def testInvalidPbtxt(self):
     with self.assertRaisesRegex(errors.InvalidArgumentError,
                                 'Could not parse input proto'):
       mlir.convert_graph_def('some invalid proto')
 
 
+class MLIRConcreteFunctionImportTest(test.TestCase):
+
+  def testImport(self):
+
+    @def_function.function
+    def identity(i):
+      return i
+
+    concrete_function = identity.get_concrete_function(
+        tensor_spec.TensorSpec(None, dtypes.float32))
+    mlir_module = mlir.convert_function(concrete_function)
+    self.assertRegex(mlir_module, r'func @.*identity.*\(')
+
+  def testImportWithCall(self):
+
+    @def_function.function
+    def callee(i):
+      return i
+
+    @def_function.function
+    def caller(i):
+      return callee(i)
+
+    concrete_function = caller.get_concrete_function(
+        tensor_spec.TensorSpec(None, dtypes.float32))
+    mlir_module = mlir.convert_function(concrete_function)
+    self.assertRegex(mlir_module, r'func @.*caller.*\(')
+    self.assertRegex(mlir_module, r'func @.*callee.*\(')
+
+  def testImportWithControlRet(self):
+
+    @def_function.function
+    def logging():
+      logging_ops.print_v2('some message')
+
+    concrete_function = logging.get_concrete_function()
+    mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
+    self.assertRegex(mlir_module, r'tf\.PrintV2')
+    self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control')
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc
index 6bc0183fdc5..fa16e5872ee 100644
--- a/tensorflow/python/mlir_wrapper.cc
+++ b/tensorflow/python/mlir_wrapper.cc
@@ -31,6 +31,17 @@ PYBIND11_MODULE(_pywrap_mlir, m) {
           return output;
         });
 
+  m.def("ImportFunction", [](const std::string &functiondef,
+                             const std::string &functiondef_library,
+                             const std::string &pass_pipeline) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    std::string output = tensorflow::ImportFunction(
+        functiondef, functiondef_library, pass_pipeline, status.get());
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+    return output;
+  });
+
   m.def("ExperimentalConvertSavedModelToMlir",
         [](const std::string &saved_model_path,
            const std::string &exported_names, bool show_debug_info) {
diff --git a/tensorflow/python/pywrap_mlir.py b/tensorflow/python/pywrap_mlir.py
index a8a8181ce48..82048140e16 100644
--- a/tensorflow/python/pywrap_mlir.py
+++ b/tensorflow/python/pywrap_mlir.py
@@ -29,6 +29,13 @@ def import_graphdef(graphdef, pass_pipeline):
       pass_pipeline.encode('utf-8'))
 
 
+def import_function(concrete_function, pass_pipeline):
+  return ImportFunction(
+      str(concrete_function.function_def).encode('utf-8'),
+      str(concrete_function.graph.as_graph_def().library).encode('utf-8'),
+      pass_pipeline.encode('utf-8'))
+
+
 def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
                                              show_debug_info):
   return ExperimentalConvertSavedModelToMlir(
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt
index e268fcf8e73..7a140a13bc6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.mlir.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.mlir.experimental"
 tf_module {
+  member_method {
+    name: "convert_function"
+    argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
+  }
   member_method {
     name: "convert_graph_def"
     argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt
index e268fcf8e73..7a140a13bc6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.mlir.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.mlir.experimental"
 tf_module {
+  member_method {
+    name: "convert_function"
+    argspec: "args=[\'concrete_function\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "
+  }
   member_method {
     name: "convert_graph_def"
     argspec: "args=[\'graph_def\', \'pass_pipeline\'], varargs=None, keywords=None, defaults=[\'tf-standard-pipeline\'], "