From e3e22538e801b829eec862795db27d40c8db84ad Mon Sep 17 00:00:00 2001
From: Amit Patankar <amitpatankar@google.com>
Date: Mon, 3 Feb 2020 13:07:13 -0800
Subject: [PATCH] Export grappler code from C++ to Python with pybind11 instead
 of SWIG. Add pywrap_required_hdrs filegroups to downstream grappler build
 files; these will go away once Bazel has first-class shared-library support.

PiperOrigin-RevId: 292986813
Change-Id: I24cbeba85e593fcb4604913874c0ac01eac73e4d
---
 tensorflow/cc/BUILD                           |  10 +
 tensorflow/core/grappler/BUILD                |  12 +
 tensorflow/core/grappler/clusters/BUILD       |  13 +
 tensorflow/core/grappler/costs/BUILD          |  19 +
 tensorflow/core/grappler/optimizers/BUILD     |  11 +
 tensorflow/core/grappler/utils/BUILD          |  10 +
 tensorflow/core/grappler/verifiers/BUILD      |  10 +
 tensorflow/python/BUILD                       | 138 +++++-
 tensorflow/python/grappler/cluster.i          | 450 ------------------
 tensorflow/python/grappler/cluster.py         |  10 +-
 tensorflow/python/grappler/cluster_wrapper.cc | 332 +++++++++++++
 tensorflow/python/grappler/cost_analyzer.i    |  67 ---
 tensorflow/python/grappler/cost_analyzer.py   |   9 +-
 .../python/grappler/cost_analyzer_wrapper.cc  |  58 +++
 tensorflow/python/grappler/item.i             | 315 ------------
 tensorflow/python/grappler/item.py            |   9 +-
 tensorflow/python/grappler/item_wrapper.cc    | 245 ++++++++++
 tensorflow/python/grappler/tf_optimizer.i     | 144 ------
 tensorflow/python/grappler/tf_optimizer.py    |  16 +-
 .../python/grappler/tf_optimizer_wrapper.cc   | 108 +++++
 tensorflow/python/tensorflow.i                |   9 +-
 .../tools/def_file_filter/symbols_pybind.txt  |  56 +++
 22 files changed, 1034 insertions(+), 1017 deletions(-)
 delete mode 100644 tensorflow/python/grappler/cluster.i
 create mode 100644 tensorflow/python/grappler/cluster_wrapper.cc
 delete mode 100644 tensorflow/python/grappler/cost_analyzer.i
 create mode 100644 tensorflow/python/grappler/cost_analyzer_wrapper.cc
 delete mode 100644 tensorflow/python/grappler/item.i
 create mode 100644 tensorflow/python/grappler/item_wrapper.cc
 delete mode 100644 tensorflow/python/grappler/tf_optimizer.i
 create mode 100644 tensorflow/python/grappler/tf_optimizer_wrapper.cc

diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index a9f429b8bd3..5251ccdf1c0 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -41,6 +41,16 @@ filegroup(
     ],
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "training/coordinator.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "gradients",
     srcs = [
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 3f79c023caf..f8ab8748285 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -4,6 +4,18 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "devices.h",
+        "grappler_item.h",
+        "grappler_item_builder.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "op_types",
     srcs = ["op_types.cc"],
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index c42e6398a59..74c8837313b 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -53,6 +53,19 @@ tf_cc_test(
     ],
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "cluster.h",
+        "single_machine.h",
+        "utils.h",
+        "virtual_cluster.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "cluster",
     srcs = ["cluster.cc"],
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index d96ea650f3f..2d547b968fc 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -36,6 +36,25 @@ tf_pyclif_proto_library(
     visibility = ["//visibility:public"],
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "analytical_cost_estimator.h",
+        "cost_estimator.h",
+        "graph_memory.h",
+        "graph_properties.h",
+        "measuring_cost_estimator.h",
+        "op_context.h",
+        "op_level_cost_estimator.h",
+        "utils.h",
+        "virtual_placer.h",
+        "virtual_scheduler.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "graph_properties",
     srcs = ["graph_properties.cc"],
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 7b229aae315..d025543b661 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -11,6 +11,17 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "graph_optimizer.h",
+        "meta_optimizer.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "static_schedule",
     srcs = ["static_schedule.cc"],
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 8941d5552b6..127bf465b3f 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -8,6 +8,16 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "topological_sort.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "scc",
     srcs = ["scc.cc"],
diff --git a/tensorflow/core/grappler/verifiers/BUILD b/tensorflow/core/grappler/verifiers/BUILD
index 068dd6a9be4..939972b0617 100644
--- a/tensorflow/core/grappler/verifiers/BUILD
+++ b/tensorflow/core/grappler/verifiers/BUILD
@@ -4,6 +4,16 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "graph_verifier.h",
+    ],
+    visibility = [
+        "//tensorflow/python:__pkg__",
+    ],
+)
+
 cc_library(
     name = "graph_verifier",
     hdrs = [
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5659bb597ec..22dc153fdbb 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -364,6 +364,40 @@ cc_library(
     ] + tf_protos_grappler(),
 )
 
+# Necessary for the pywrap inclusion below. Combining targets does not work
+# properly.
+tf_pybind_cc_library_wrapper(
+    name = "cost_analyzer_headers",
+    deps = [
+        ":cost_analyzer_lib",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_cost_analyzer",
+    srcs = ["grappler/cost_analyzer_wrapper.cc"],
+    hdrs = [
+        "grappler/cost_analyzer.h",
+        "//tensorflow/cc:pywrap_required_hdrs",
+        "//tensorflow/core/grappler:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
+        "//tensorflow/core/public:session.h",
+        "//tensorflow/core/public:session_options.h",
+    ],
+    module_name = "_pywrap_cost_analyzer",
+    deps = [
+        ":cost_analyzer_headers",
+        ":pybind11_status",
+        "//tensorflow/core:core_cpu_headers_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@pybind11",
+    ],
+)
+
 cc_library(
     name = "model_analyzer_lib",
     srcs = ["grappler/model_analyzer.cc"],
@@ -380,14 +414,16 @@ cc_library(
 tf_python_pybind_extension(
     name = "_pywrap_model_analyzer",
     srcs = ["grappler/model_analyzer_wrapper.cc"],
-    hdrs = ["grappler/model_analyzer.h"],
+    hdrs = [
+        "grappler/model_analyzer.h",
+        "//tensorflow/core/grappler:pywrap_required_hdrs",
+    ],
     module_name = "_pywrap_model_analyzer",
     deps = [
-        ":model_analyzer_lib",
         ":pybind11_status",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/grappler:grappler_item_builder",
         "@pybind11",
     ],
 )
@@ -5652,10 +5688,6 @@ tf_py_wrap_cc(
     name = "pywrap_tensorflow_internal",
     srcs = ["tensorflow.i"],
     swig_includes = [
-        "grappler/cluster.i",
-        "grappler/cost_analyzer.i",
-        "grappler/item.i",
-        "grappler/tf_optimizer.i",
         "lib/core/strings.i",
         "platform/base.i",
     ],
@@ -5721,10 +5753,11 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
     ":cpp_python_util",  # util
     ":py_func_lib",  # py_func
     ":model_analyzer_lib",  # model_analyzer
-    "//tensorflow/core/util:port",  # util_port
+    ":cost_analyzer_lib",  # cost_analyzer
     "//tensorflow/stream_executor:stream_executor_pimpl",  # stat_summarizer
     "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",  # graph_analyzer
     "//tensorflow/core/profiler/internal:print_model_analysis",  # tfprof
+    "//tensorflow/core/util:port",  # util_port
     "//tensorflow/core:framework_internal_impl",  # op_def_registry
     "//tensorflow/core:lib_internal_impl",  # device_lib
     "//tensorflow/core:core_cpu_impl",  # device_lib
@@ -5749,6 +5782,20 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
     "//tensorflow/core:core_cpu_base_no_ops",  # tf_session
     "//tensorflow/c:python_api",  # tf_session
     "//tensorflow/python:tf_session_helper",  # tf_session
+    "//tensorflow/core/grappler:grappler_item",  # tf_item
+    "//tensorflow/core/grappler/costs:graph_properties",  # tf_item
+    "//tensorflow/core/grappler:grappler_item_builder",  # tf_item
+    "//tensorflow/core/grappler/utils:topological_sort",  # tf_item
+    "//tensorflow/core/grappler/costs:utils",  # tf_cluster
+    "//tensorflow/core/grappler/optimizers:meta_optimizer",  # tf_optimizer
+    "//tensorflow/core/grappler/clusters:cluster",  # tf_cluster
+    "//tensorflow/core/grappler/clusters:single_machine",  # tf_cluster
+    "//tensorflow/core/grappler/costs:op_level_cost_estimator",  # tf_cluster
+    "//tensorflow/core/grappler/clusters:virtual_cluster",  # tf_cluster
+    "//tensorflow/core/grappler/costs:graph_memory",  # tf_cluster
+    "//tensorflow/core/grappler:devices",  # tf_cluster
+    "//tensorflow/core/grappler/clusters:utils",  # tf_optimizer
+    "//tensorflow/core/grappler/costs:measuring_cost_estimator",  # tf_cluster
 ]
 
 # Filter the DEF file to reduce the number of symbols to 64K or less.
@@ -7359,11 +7406,32 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":pywrap_tensorflow_internal",
+        ":_pywrap_tf_item",
         "//tensorflow/core/grappler/costs:op_performance_data_py",
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_tf_item",
+    srcs = ["grappler/item_wrapper.cc"],
+    hdrs = [
+        "//tensorflow/cc:pywrap_required_hdrs",
+        "//tensorflow/core/grappler:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/utils:pywrap_required_hdrs",
+    ],
+    module_name = "_pywrap_tf_item",
+    deps = [
+        ":pybind11_status",
+        "@pybind11",
+        "//tensorflow/core:core_cpu_headers_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
+        "//tensorflow/core:protos_all_cc",
+    ] + if_not_windows(["//tensorflow/core/grappler/costs:graph_properties"]),  # b/148556093,
+)
+
 tf_py_test(
     name = "item_test",
     size = "small",
@@ -7414,11 +7482,34 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":pywrap_tensorflow_internal",
+        ":_pywrap_tf_cluster",
         "//tensorflow/core/grappler/costs:op_performance_data_py",
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_tf_cluster",
+    srcs = ["grappler/cluster_wrapper.cc"],
+    hdrs = [
+        "//tensorflow/cc:pywrap_required_hdrs",
+        "//tensorflow/core/grappler:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/utils:pywrap_required_hdrs",
+    ],
+    module_name = "_pywrap_tf_cluster",
+    deps = [
+        ":pybind11_status",
+        "//tensorflow/core:core_cpu_headers_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+    ],
+)
+
 cuda_py_test(
     name = "cluster_test",
     size = "small",
@@ -7451,11 +7542,34 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":pywrap_tensorflow_internal",
+        ":_pywrap_tf_optimizer",
         ":tf_cluster",
     ],
 )
 
+tf_python_pybind_extension(
+    name = "_pywrap_tf_optimizer",
+    srcs = ["grappler/tf_optimizer_wrapper.cc"],
+    hdrs = [
+        "//tensorflow/cc:pywrap_required_hdrs",
+        "//tensorflow/core/grappler:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/clusters:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/costs:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/optimizers:pywrap_required_hdrs",
+        "//tensorflow/core/grappler/verifiers:pywrap_required_hdrs",
+    ],
+    module_name = "_pywrap_tf_optimizer",
+    deps = [
+        ":pybind11_status",
+        "//tensorflow/core:core_cpu_headers_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_id",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@pybind11",
+    ],
+)
+
 tf_py_test(
     name = "tf_optimizer_test",
     size = "small",
@@ -7619,7 +7733,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
-        ":pywrap_tensorflow_internal",
+        ":_pywrap_cost_analyzer",
         ":tf_cluster",
         ":tf_item",
     ],
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i
deleted file mode 100644
index e2fa8bcad40..00000000000
--- a/tensorflow/python/grappler/cluster.i
+++ /dev/null
@@ -1,450 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/platform/base.i"
-%include <std_shared_ptr.i>
-%include "item.i"
-
-// Wrap the cluster into an object that swig can manipulate. This ensures it will call the object
-// destructor upon garbage collection instead of leaking memory.
-struct GCluster {
-  std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
-};
-
-%{
-#include "tensorflow/core/protobuf/device_properties.pb.h"
-
-template <>
-bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    return false;
-  }
-
-  tensorflow::NamedDevice named_device;
-  if (!named_device.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The NamedDevice could not be parsed as a valid protocol buffer");
-    return false;
-  }
-  if (out) *out = named_device;
-  return true;
-}
-%}
-
-%typemap(in) const std::vector<tensorflow::NamedDevice>& (std::vector<tensorflow::NamedDevice> temp) {
-  if (!tf_vector_input_helper($input, &temp, &_PyObjAs<tensorflow::NamedDevice>)) {
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%typemap(in) const tensorflow::NamedDevice& (tensorflow::NamedDevice temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The NamedDevice could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%typemap(in) const tensorflow::RunMetadata& (tensorflow::RunMetadata temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The RunMetadata could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%typemap(in) const string& (string temp) {
-  char *buf;
-  Py_ssize_t len;
-  if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) return NULL;
-  temp.assign(buf, len);
-  $1 = &temp;
-}
-
-%{
-#include <memory>
-#include <vector>
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/clusters/single_machine.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-#include "tensorflow/core/grappler/costs/graph_memory.h"
-#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
-#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
-#include "tensorflow/core/grappler/costs/utils.h"
-#include "tensorflow/core/protobuf/device_properties.pb.h"
-#include "tensorflow/core/framework/kernel_def.pb.h"
-#include "tensorflow/core/framework/memory_types.h"
-
-// Provide the implementation of the GCluster struct here.
-struct GCluster {
-  GCluster() {}
-  GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
-
-  tensorflow::grappler::Cluster* operator->() const {
-    return cluster_.get();
-  }
-  tensorflow::grappler::Cluster* get() const {
-    return cluster_.get();
-  }
-  bool is_none() const {
-    return cluster_.get() == nullptr;
-  }
-
-  std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
-};
-
-
-static GCluster TF_NewCluster(bool allow_soft_placement,
-                   bool disable_detailed_stats, TF_Status* status) {
-  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
-  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
-  int timeout_s = 60 * 10;
-  tensorflow::grappler::Cluster* cluster_ =
-      new tensorflow::grappler::SingleMachine(
-          timeout_s, num_cpu_cores, num_gpus);
-  cluster_->DisableDetailedStats(disable_detailed_stats);
-  cluster_->AllowSoftPlacement(allow_soft_placement);
-  cluster_->SetNumWarmupSteps(10);
-  tensorflow::Status s = cluster_->Provision();
-  tensorflow::Set_TF_Status_from_Status(status, s);
-  return GCluster(cluster_);
-}
-
-static GCluster TF_NewVirtualCluster(
-    const std::vector<tensorflow::NamedDevice>& named_devices, TF_Status* status) {
-  std::unordered_map<string, tensorflow::DeviceProperties> devices;
-  for (const auto& named_device : named_devices) {
-    devices[named_device.name()]= named_device.properties();
-  }
-  tensorflow::grappler::Cluster* cluster_ =
-      new tensorflow::grappler::VirtualCluster(devices);
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  tensorflow::Status s = cluster_->Provision();
-  PyGILState_Release(gstate);
-  tensorflow::Set_TF_Status_from_Status(status, s);
-  return GCluster(cluster_);
-}
-
-static void TF_ShutdownCluster(GCluster cluster) {
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  cluster->Shutdown();
-  PyGILState_Release(gstate);
-}
-
-tensorflow::Status _GetOpPerformanceDataAndRunTime(
-    const tensorflow::grappler::GrapplerItem& item,
-    tensorflow::grappler::CostEstimator* cost_measure,
-    tensorflow::OpPerformanceList* op_performance_data,
-    tensorflow::grappler::Costs* costs) {
-  tensorflow::Status status = cost_measure->Initialize(item);
-  if (!status.ok()) return status;
-
-  tensorflow::RunMetadata run_metadata;
-  TF_RETURN_IF_ERROR(
-      cost_measure->PredictCosts(item.graph, &run_metadata, costs));
-
-  if (op_performance_data) {
-    *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
-        run_metadata.cost_graph(), item.graph);
-  }
-  return tensorflow::Status::OK();
-}
-
-static PyObject* TF_ListDevices(GCluster cluster) {
-  const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyList_New(devices.size());
-  int i = 0;
-  for (auto& dev : devices) {
-    tensorflow::NamedDevice d;
-    d.set_name(dev.first);
-    *d.mutable_properties() = dev.second;
-    string dev_str = d.SerializeAsString();
-    PyObject* dev_obj = PyBytes_FromStringAndSize(dev_str.data(),
-                                                  dev_str.size());
-    PyList_SetItem(result, i, dev_obj);
-    ++i;
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-static PyObject* TF_ListAvailableOps() {
-  tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
-  std::vector<tensorflow::OpDef> ops;
-  registry->GetRegisteredOps(&ops);
-  std::vector<string> op_names;
-  for (const tensorflow::OpDef& op : ops) {
-    op_names.push_back(op.name());
-  }
-  std::sort(op_names.begin(), op_names.end());
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyList_New(op_names.size());
-  for (int i = 0; i < op_names.size(); ++i) {
-    PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
-  if (cluster.is_none() || item.is_none()) {
-    Py_RETURN_NONE;
-  }
-  const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
-  std::unordered_map<string, std::vector<string>> device_types;
-  for (const auto& dev : devices) {
-    device_types[dev.second.type()].push_back(dev.first);
-  }
-
-  std::unordered_map<string, std::set<string>> supported_device_types;
-  std::unordered_map<string, std::set<string>> device_restrictions;
-
-  for (const auto& node : item->graph.node()) {
-    for (const auto& dev : device_types) {
-      const string& type = dev.first;
-      if (cluster->type() != "single_machine") {
-        // The actual kernel may not be linked in this binary.
-        supported_device_types[node.name()].insert(type);
-      } else {
-        // Check the kernel capabilities
-        const tensorflow::DeviceType dev_type(type);
-        tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
-        if (s.ok()) {
-          supported_device_types[node.name()].insert(type);
-
-          // Check which inputs are restricted to reside on the host.
-          // TODO: extends this to support outputs as well
-          tensorflow::MemoryTypeVector inp_mtypes;
-          tensorflow::MemoryTypeVector out_mtypes;
-          s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node,
-                                             &inp_mtypes, &out_mtypes);
-          if (s.ok()) {
-            for (int i = 0; i < inp_mtypes.size(); ++i) {
-              if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
-                device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU");
-                break;
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyDict_New();
-
-  for (const auto& supported_dev : supported_device_types) {
-    const string& node = supported_dev.first;
-    std::set<string> feasible;
-    const auto it = device_restrictions.find(node);
-    if (it != device_restrictions.end()) {
-      const std::set<string>& candidates = supported_dev.second;
-      const std::set<string>& valid = it->second;
-      std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(),
-                            std::inserter(feasible, feasible.begin()));
-    } else {
-      feasible = supported_dev.second;
-    }
-
-    std::vector<string> device_names;
-    for (const string& type : feasible) {
-      auto it = device_types.find(type);
-      CHECK(it != device_types.end());
-      for (const string& name : it->second) {
-        device_names.push_back(name);
-      }
-    }
-
-    PyObject* dev = PyList_New(device_names.size());
-    for (int i = 0; i < device_names.size(); ++i) {
-      PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str()));
-    }
-    CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev));
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-
-static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
-  tensorflow::grappler::OpLevelCostEstimator estimator;
-  tensorflow::grappler::DeviceInfo info =
-      estimator.GetDeviceInfo(device.properties());
-  return info.gigaops;
-}
-
-static PyObject* TF_MeasureCosts(
-    GItem item,
-    GCluster cluster,
-    bool generate_timeline, TF_Status* status) {
-  tensorflow::OpPerformanceList op_performance_data;
-  tensorflow::StepStats step_stats;
-
-  const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
-  tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
-
-  tensorflow::grappler::Costs costs;
-  tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
-      *item, &cost_measure, &op_performance_data, &costs);
-  double run_time = FLT_MAX;
-  if (s.ok()) {
-    run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
-  }
-  if (generate_timeline) {
-    tensorflow::RunMetadata metadata;
-    tensorflow::Status run_status = cluster->Run(
-        item->graph, item->feed, item->fetch, &metadata);
-    if (run_status.ok()) {
-      step_stats = metadata.step_stats();
-    } else {
-      s = run_status;
-    }
-  }
-
-  tensorflow::Set_TF_Status_from_Status(status, s);
-  if (!s.ok()) {
-    Py_RETURN_NONE;
-  }
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* op_perf_objs = PyList_New(
-      op_performance_data.op_performance_size());
-  for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
-    string op_perf_str =
-        op_performance_data.op_performance(i).SerializeAsString();
-    PyObject* op_perf_obj = PyBytes_FromStringAndSize(op_perf_str.data(),
-                                                      op_perf_str.size());
-    PyList_SetItem(op_perf_objs, i, op_perf_obj);
-  }
-
-  PyObject* run_time_obj = PyFloat_FromDouble(run_time);
-
-  string step_stats_str = step_stats.SerializeAsString();
-  PyObject* metadata_obj = PyBytes_FromStringAndSize(step_stats_str.data(),
-                                                     step_stats_str.size());
-
-  PyObject* ret = PyTuple_New(3);
-  if (PyTuple_SetItem(ret, 0, op_perf_objs) != 0 ||
-      PyTuple_SetItem(ret, 1, run_time_obj) != 0 ||
-      PyTuple_SetItem(ret, 2, metadata_obj) != 0) {
-    Py_DECREF(ret);
-    Py_XDECREF(op_perf_objs);
-    Py_XDECREF(run_time_obj);
-    Py_XDECREF(metadata_obj);
-    s = tensorflow::Status(tensorflow::error::Code::INTERNAL,
-                           "Error setting return tuples.");
-    tensorflow::Set_TF_Status_from_Status(status, s);
-    Py_INCREF(Py_None);
-    ret = Py_None;
-  }
-  PyGILState_Release(gstate);
-  return ret;
-}
-
-
-static PyObject* TF_DeterminePeakMemoryUsage(
-    GItem item,
-    GCluster cluster,
-    TF_Status* status) {
-  if (item.is_none() || cluster.is_none()) {
-    tensorflow::Status s(tensorflow::error::Code::INTERNAL,
-                         "You need both a cluster and an item to determine peak memory usage");
-    tensorflow::Set_TF_Status_from_Status(status, s);
-    Py_RETURN_NONE;
-  }
-  tensorflow::grappler::GraphMemory memory(*item);
-
-  tensorflow::Status s;
-  if (cluster->DetailedStatsEnabled()) {
-    s = memory.InferDynamically(cluster.get());
-  } else {
-    s = memory.InferStatically(cluster->GetDevices());
-  }
-  if (!s.ok()) {
-    tensorflow::Set_TF_Status_from_Status(status, s);
-    Py_RETURN_NONE;
-  }
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyDict_New();
-  for (const auto& device : cluster->GetDevices()) {
-    const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
-        memory.GetPeakMemoryUsage(device.first);
-    PyObject* per_device = PyList_New(usage.live_tensors.size());
-    for (int i = 0; i < usage.live_tensors.size(); ++i) {
-      const auto& live_tensor = usage.live_tensors[i];
-      PyObject* live = PyTuple_New(5);
-      PyTuple_SetItem(live, 0, PyString_FromString(live_tensor.node.c_str()));
-      PyTuple_SetItem(live, 1, PyInt_FromLong(live_tensor.output_id));
-      PyTuple_SetItem(live, 2, PyLong_FromLong(live_tensor.memory_used));
-      PyTuple_SetItem(live, 3, PyLong_FromLong(live_tensor.allocation_time.count()));
-      PyTuple_SetItem(live, 4, PyLong_FromLong(live_tensor.deallocation_time.count()));
-      PyList_SetItem(per_device, i, live);
-
-    }
-    PyObject* ret = PyTuple_New(2);
-    PyTuple_SetItem(ret, 0, PyLong_FromLong(usage.used_memory));
-    PyTuple_SetItem(ret, 1, per_device);
-    PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-%}
-
-// Wrap these functions.
-static GCluster TF_NewCluster(
-    bool allow_soft_placement, bool disable_detailed_stats, TF_Status* status);
-static GCluster TF_NewVirtualCluster(
-    const std::vector<tensorflow::NamedDevice>& named_devices,
-    TF_Status* status);
-static void TF_ShutdownCluster(GCluster cluster);
-static PyObject* TF_ListDevices(GCluster cluster);
-static PyObject* TF_ListAvailableOps();
-static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
-static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
-static PyObject* TF_MeasureCosts(
-    GItem item, GCluster cluster,
-    bool generate_timeline, TF_Status* status);
-static PyObject* TF_DeterminePeakMemoryUsage(
-    GItem item, GCluster cluster,
-    TF_Status* status);
diff --git a/tensorflow/python/grappler/cluster.py b/tensorflow/python/grappler/cluster.py
index 5fd44a75305..a58388f25e9 100644
--- a/tensorflow/python/grappler/cluster.py
+++ b/tensorflow/python/grappler/cluster.py
@@ -23,7 +23,7 @@ import contextlib
 from tensorflow.core.framework import step_stats_pb2
 from tensorflow.core.grappler.costs import op_performance_data_pb2
 from tensorflow.core.protobuf import device_properties_pb2
-from tensorflow.python import pywrap_tensorflow as tf_cluster
+from tensorflow.python import _pywrap_tf_cluster as tf_cluster
 
 
 class Cluster(object):
@@ -92,13 +92,9 @@ class Cluster(object):
       item: The item for which to measure the costs.
     Returns: The triplet op_perfs, runtime, step_stats.
     """
-    ret_from_swig = tf_cluster.TF_MeasureCosts(item.tf_item, self._tf_cluster,
-                                               self._generate_timeline)
+    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
+        item.tf_item, self._tf_cluster, self._generate_timeline)
 
-    if ret_from_swig is None:
-      return None
-
-    op_perf_bytes_list, run_time, step_stats_bytes = ret_from_swig
     op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                 for op_perf_bytes in op_perf_bytes_list]
     return (op_perfs, run_time,
diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc
new file mode 100644
index 00000000000..685a95ccf93
--- /dev/null
+++ b/tensorflow/python/grappler/cluster_wrapper.cc
@@ -0,0 +1,332 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cfloat>
+#include <cstdint>
+#include <memory>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/stl.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/memory_types.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/clusters/single_machine.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/graph_memory.h"
+#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+tensorflow::Status _GetOpPerformanceDataAndRunTime(
+    const tensorflow::grappler::GrapplerItem& item,
+    tensorflow::grappler::CostEstimator* cost_measure,
+    tensorflow::OpPerformanceList* op_performance_data,
+    tensorflow::grappler::Costs* costs) {
+  tensorflow::Status status = cost_measure->Initialize(item);
+  if (!status.ok()) return status;
+
+  tensorflow::RunMetadata run_metadata;
+  MaybeRaiseRegisteredFromStatus(
+      cost_measure->PredictCosts(item.graph, &run_metadata, costs));
+
+  if (op_performance_data) {
+    *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
+        run_metadata.cost_graph(), item.graph);
+  }
+  return tensorflow::Status::OK();
+}
+
+PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster);
+
+PYBIND11_MODULE(_pywrap_tf_cluster, m) {
+  py::class_<tensorflow::grappler::Cluster> grappler_cluster(
+      m, "tensorflow::grappler::Cluster");
+
+  m.def("TF_NewCluster",
+        [](bool allow_soft_placement,
+           bool disable_detailed_stats) -> tensorflow::grappler::Cluster* {
+          // TODO(petebu): Make these named arguments with default values
+          // instead.
+          int num_cpu_cores =
+              tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+          int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
+          int timeout_s = 60 * 10;
+          std::unique_ptr<tensorflow::grappler::Cluster> cluster =
+              std::make_unique<tensorflow::grappler::SingleMachine>(
+                  timeout_s, num_cpu_cores, num_gpus);
+          cluster->DisableDetailedStats(disable_detailed_stats);
+          cluster->AllowSoftPlacement(allow_soft_placement);
+          cluster->SetNumWarmupSteps(10);
+          MaybeRaiseRegisteredFromStatus(cluster->Provision());
+          return cluster.release();
+        });
+
+  m.def("TF_NewVirtualCluster",
+        [](const std::vector<py::bytes>& serialized_named_devices)
+            -> tensorflow::grappler::Cluster* {
+          std::vector<tensorflow::NamedDevice> named_devices;
+          for (const auto& s : serialized_named_devices) {
+            tensorflow::NamedDevice named_device;
+            if (!named_device.ParseFromString(s)) {
+              throw std::invalid_argument(
+                  "The NamedDevice could not be parsed as a valid protocol "
+                  "buffer");
+            }
+            named_devices.push_back(named_device);
+          }
+
+          std::unordered_map<std::string, tensorflow::DeviceProperties> devices;
+          for (const auto& named_device : named_devices) {
+            devices[named_device.name()] = named_device.properties();
+          }
+          std::unique_ptr<tensorflow::grappler::Cluster> cluster =
+              std::make_unique<tensorflow::grappler::VirtualCluster>(devices);
+          {
+            // TODO(petebu): Do we need to hold the GIL here?
+            py::gil_scoped_acquire acquire;
+            MaybeRaiseRegisteredFromStatus(cluster->Provision());
+          }
+          return cluster.release();
+        });
+
+  m.def("TF_ShutdownCluster", [](tensorflow::grappler::Cluster* cluster) {
+    // TODO(petebu): Do we need to hold the GIL here?
+    py::gil_scoped_acquire acquire;
+    cluster->Shutdown();
+  });
+
+  m.def("TF_ListDevices",
+        [](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> {
+          const std::unordered_map<std::string, tensorflow::DeviceProperties>&
+              devices = cluster->GetDevices();
+          std::vector<py::bytes> named_devices;
+          for (auto& dev : devices) {
+            tensorflow::NamedDevice d;
+            d.set_name(dev.first);
+            *d.mutable_properties() = dev.second;
+            named_devices.push_back(d.SerializeAsString());
+          }
+          return named_devices;
+        });
+
+  m.def("TF_ListAvailableOps", []() -> std::vector<std::string> {
+    tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
+    std::vector<tensorflow::OpDef> ops;
+    registry->GetRegisteredOps(&ops);
+    std::vector<std::string> op_names;
+    for (const tensorflow::OpDef& op : ops) {
+      op_names.push_back(op.name());
+    }
+    std::sort(op_names.begin(), op_names.end());
+    return op_names;
+  });
+
+  m.def(
+      "TF_GetSupportedDevices",
+      [](tensorflow::grappler::Cluster* cluster,
+         tensorflow::grappler::GrapplerItem* item)
+          -> std::unordered_map<std::string, std::vector<std::string>> {
+        if (cluster == nullptr || item == nullptr) {
+          MaybeRaiseRegisteredFromStatus(tensorflow::Status(
+              tensorflow::errors::Internal("You need both a cluster and an "
+                                           "item to get supported devices.")));
+        }
+        const std::unordered_map<std::string, tensorflow::DeviceProperties>&
+            devices = cluster->GetDevices();
+        std::unordered_map<std::string, std::vector<std::string>> device_types;
+        for (const auto& dev : devices) {
+          device_types[dev.second.type()].push_back(dev.first);
+        }
+
+        std::unordered_map<std::string, std::set<std::string>>
+            supported_device_types;
+        std::unordered_map<std::string, std::set<std::string>>
+            device_restrictions;
+
+        for (const auto& node : item->graph.node()) {
+          for (const auto& dev : device_types) {
+            const std::string& type = dev.first;
+            if (cluster->type() != "single_machine") {
+              // The actual kernel may not be linked in this binary.
+              supported_device_types[node.name()].insert(type);
+            } else {
+              // Check the kernel capabilities
+              const tensorflow::DeviceType dev_type(type);
+              tensorflow::Status s =
+                  tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
+              if (s.ok()) {
+                supported_device_types[node.name()].insert(type);
+
+                // Check which inputs are restricted to reside on the host.
+                // TODO: extend this to support outputs as well
+                tensorflow::MemoryTypeVector inp_mtypes;
+                tensorflow::MemoryTypeVector out_mtypes;
+                tensorflow::Status s = tensorflow::MemoryTypesForNode(
+                    tensorflow::OpRegistry::Global(), dev_type, node,
+                    &inp_mtypes, &out_mtypes);
+                if (s.ok()) {
+                  for (int i = 0; i < inp_mtypes.size(); ++i) {
+                    if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
+                      device_restrictions[tensorflow::grappler::NodeName(
+                                              node.input(i))]
+                          .insert("CPU");
+                      break;
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+
+        std::unordered_map<std::string, std::vector<std::string>> result;
+        for (const auto& supported_dev : supported_device_types) {
+          const std::string& node = supported_dev.first;
+          std::set<std::string> feasible;
+          const auto it = device_restrictions.find(node);
+          if (it != device_restrictions.end()) {
+            const std::set<std::string>& candidates = supported_dev.second;
+            const std::set<std::string>& valid = it->second;
+            std::set_intersection(candidates.begin(), candidates.end(),
+                                  valid.begin(), valid.end(),
+                                  std::inserter(feasible, feasible.begin()));
+          } else {
+            feasible = supported_dev.second;
+          }
+
+          std::vector<std::string> device_names;
+          for (const std::string& type : feasible) {
+            auto it = device_types.find(type);
+            DCHECK(it != device_types.end());
+            for (const std::string& name : it->second) {
+              device_names.push_back(name);
+            }
+          }
+          result[node] = device_names;
+        }
+        return result;
+      });
+
+  m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
+    tensorflow::NamedDevice device;
+    if (!device.ParseFromString(serialized_device)) {
+      throw std::invalid_argument(
+          "The NamedDevice could not be parsed as a valid protocol buffer");
+    }
+    tensorflow::grappler::OpLevelCostEstimator estimator;
+    tensorflow::grappler::DeviceInfo info =
+        estimator.GetDeviceInfo(device.properties());
+    return info.gigaops;
+  });
+
+  m.def("TF_MeasureCosts",
+        [](tensorflow::grappler::GrapplerItem* item,
+           tensorflow::grappler::Cluster* cluster, bool generate_timeline)
+            -> std::tuple<std::vector<py::bytes>, double, py::bytes> {
+          const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
+          tensorflow::grappler::MeasuringCostEstimator cost_measure(
+              cluster, num_measurements, 0);
+
+          tensorflow::OpPerformanceList op_performance_data;
+          tensorflow::grappler::Costs costs;
+          tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
+              *item, &cost_measure, &op_performance_data, &costs);
+          double run_time = FLT_MAX;
+          if (s.ok()) {
+            run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
+          }
+          tensorflow::StepStats step_stats;
+          if (generate_timeline) {
+            tensorflow::RunMetadata metadata;
+            MaybeRaiseRegisteredFromStatus(
+                cluster->Run(item->graph, item->feed, item->fetch, &metadata));
+            step_stats = metadata.step_stats();
+          }
+
+          std::vector<py::bytes> op_perf_objs;
+          op_perf_objs.resize(op_performance_data.op_performance_size());
+          for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
+            op_perf_objs[i] =
+                op_performance_data.op_performance(i).SerializeAsString();
+          }
+
+          py::bytes step_stats_str = step_stats.SerializeAsString();
+          return std::make_tuple(op_perf_objs, run_time, step_stats_str);
+        });
+
+  using DurationType = tensorflow::grappler::Costs::Duration::rep;
+  using MemoryUsage =
+      std::tuple<std::string, int, size_t, DurationType, DurationType>;
+
+  m.def(
+      "TF_DeterminePeakMemoryUsage",
+      [](tensorflow::grappler::GrapplerItem* item,
+         tensorflow::grappler::Cluster* cluster)
+          -> std::unordered_map<std::string,
+                                std::tuple<int64_t, std::vector<MemoryUsage>>> {
+        if (item == nullptr || cluster == nullptr) {
+          MaybeRaiseRegisteredFromStatus(
+              tensorflow::Status(tensorflow::errors::Internal(
+                  "You need both a cluster and an item to determine peak "
+                  "memory usage.")));
+        }
+        tensorflow::grappler::GraphMemory memory(*item);
+
+        if (cluster->DetailedStatsEnabled()) {
+          MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster));
+        } else {
+          MaybeRaiseRegisteredFromStatus(
+              memory.InferStatically(cluster->GetDevices()));
+        }
+
+        std::unordered_map<std::string,
+                           std::tuple<int64_t, std::vector<MemoryUsage>>>
+            result;
+        for (const auto& device : cluster->GetDevices()) {
+          const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
+              memory.GetPeakMemoryUsage(device.first);
+          std::vector<MemoryUsage> per_device;
+          for (int i = 0; i < usage.live_tensors.size(); ++i) {
+            const auto& live_tensor = usage.live_tensors[i];
+            per_device.push_back(std::make_tuple(
+                live_tensor.node, live_tensor.output_id,
+                live_tensor.memory_used, live_tensor.allocation_time.count(),
+                live_tensor.deallocation_time.count()));
+          }
+          result[device.first] = std::make_tuple(usage.used_memory, per_device);
+        }
+        return result;
+      });
+}
diff --git a/tensorflow/python/grappler/cost_analyzer.i b/tensorflow/python/grappler/cost_analyzer.i
deleted file mode 100644
index 8f7fdb47f26..00000000000
--- a/tensorflow/python/grappler/cost_analyzer.i
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/lib/core/strings.i"
-%include "tensorflow/python/platform/base.i"
-%include "cluster.i"
-
-%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The MetaGraphDef could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%{
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/clusters/single_machine.h"
-#include "tensorflow/core/grappler/devices.h"
-#include "tensorflow/core/grappler/grappler_item_builder.h"
-#include "tensorflow/python/grappler/cost_analyzer.h"
-%}
-
-%{
-string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
-                          bool verbose, GCluster cluster) {
-  tensorflow::grappler::ItemConfig cfg;
-  cfg.apply_optimizations = false;
-  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
-      tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg);
-  if (!item) {
-    return "Error: failed to preprocess metagraph: check your log file for errors";
-  }
-
-  string suffix;
-  tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);
-
-  std::stringstream os;
-  analyzer.GenerateReport(os, per_node_report, verbose);
-  return os.str();
-}
-
-%}
-
-string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
-                          bool verbose, GCluster cluster);
diff --git a/tensorflow/python/grappler/cost_analyzer.py b/tensorflow/python/grappler/cost_analyzer.py
index 5cb9abe6386..a00a58e302f 100644
--- a/tensorflow/python/grappler/cost_analyzer.py
+++ b/tensorflow/python/grappler/cost_analyzer.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python import pywrap_tensorflow as tf_wrap
+from tensorflow.python import _pywrap_cost_analyzer as tf_wrap
 from tensorflow.python.grappler import cluster as gcluster
 from tensorflow.python.grappler import item as gitem
 
@@ -44,10 +44,9 @@ def GenerateCostReport(metagraph,
   if cluster is None:
     cluster = gcluster.Cluster(disable_detailed_stats=False)
 
-  ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
-                                             per_node_report, verbose,
-                                             cluster.tf_cluster)
-  return ret_from_swig
+  return tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
+                                    per_node_report, verbose,
+                                    cluster.tf_cluster)
 
 
 def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):
diff --git a/tensorflow/python/grappler/cost_analyzer_wrapper.cc b/tensorflow/python/grappler/cost_analyzer_wrapper.cc
new file mode 100644
index 00000000000..31fc0384a1b
--- /dev/null
+++ b/tensorflow/python/grappler/cost_analyzer_wrapper.cc
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/core/grappler/clusters/single_machine.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/python/grappler/cost_analyzer.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_cost_analyzer, m) {
+  m.def("GenerateCostReport",
+        [](const py::bytes& serialized_metagraph, bool per_node_report,
+           bool verbose, tensorflow::grappler::Cluster* cluster) -> py::bytes {
+          tensorflow::MetaGraphDef metagraph;
+          if (!metagraph.ParseFromString(serialized_metagraph)) {
+            return "The MetaGraphDef could not be parsed as a valid protocol "
+                   "buffer";
+          }
+
+          tensorflow::grappler::ItemConfig cfg;
+          cfg.apply_optimizations = false;
+          std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
+              tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+                  "metagraph", metagraph, cfg);
+          if (item == nullptr) {
+            return "Error: failed to preprocess metagraph: check your log file "
+                   "for errors";
+          }
+
+          std::string suffix;
+          tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix);
+
+          std::stringstream os;
+          tensorflow::MaybeRaiseFromStatus(
+              analyzer.GenerateReport(os, per_node_report, verbose));
+          return py::bytes(os.str());
+        });
+}
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i
deleted file mode 100644
index 93228784b7b..00000000000
--- a/tensorflow/python/grappler/item.i
+++ /dev/null
@@ -1,315 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include <std_shared_ptr.i>
-%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The MetaGraphDef could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-// Wrap the item into an object that swig can manipulate. This ensures it will call the object
-// destructor upon garbage collection instead of leaking memory.
-struct GItem {
-  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
-};
-
-
-%{
-#include <unordered_set>
-#include <map>
-#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
-#include "tensorflow/core/grappler/grappler_item_builder.h"
-#include "tensorflow/core/grappler/costs/graph_properties.h"
-#include "tensorflow/core/grappler/utils/topological_sort.h"
-#include "tensorflow/core/protobuf/error_codes.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/protobuf/meta_graph.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-// Provide the implementation fo the GItem struct here.
-struct GItem {
-  GItem() {}
-  GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}
-
-  tensorflow::grappler::GrapplerItem* operator->() const {
-    return item_.get();
-  }
-  const tensorflow::grappler::GrapplerItem& operator*() const {
-    return *item_.get();
-  }
-  bool is_none() const {
-    return item_.get() == nullptr;
-  }
-  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
-};
-
-static GItem TF_NewItem(
-    const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
-    bool ignore_user_placement, TF_Status* status) {
-  if (meta_graph.collection_def().count("train_op") == 0) {
-    tensorflow::Set_TF_Status_from_Status(
-        status,
-        tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
-    return nullptr;
-  }
-
-  tensorflow::grappler::ItemConfig cfg;
-  cfg.ignore_user_placement = ignore_user_placement;
-  cfg.ignore_colocation = ignore_colocation;
-  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
-      tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
-  if (!item) {
-    tensorflow::Set_TF_Status_from_Status(
-        status,
-        tensorflow::errors::InvalidArgument("Invalid metagraph"));
-    return nullptr;
-  }
-  tensorflow::Set_TF_Status_from_Status(status, tensorflow::Status::OK());
-  return GItem(item.release());
-}
-
-static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
-                                                   TF_Status* status) {
-  if (item.is_none()) {
-    Py_RETURN_NONE;
-  }
-
-  std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
-  std::vector<const tensorflow::NodeDef*> enqueue_ops = item->EnqueueOpsFanin();
-  std::unordered_set<string> op_names;
-  for (auto op : main_ops) {
-    op_names.insert(op->name());
-  }
-  for (auto op : enqueue_ops) {
-    op_names.insert(op->name());
-  }
-
-  std::vector<string> ops;
-  if (sort_topologically) {
-    tensorflow::GraphDef subgraph;
-    for (const tensorflow::NodeDef& node : item->graph.node()) {
-      if (op_names.find(node.name()) != op_names.end()) {
-        *subgraph.add_node() = node;
-      }
-    }
-    tensorflow::Status s = tensorflow::grappler::TopologicalSort(&subgraph);
-    tensorflow::Set_TF_Status_from_Status(status, s);
-    for (const tensorflow::NodeDef& node : subgraph.node()) {
-      ops.push_back(node.name());
-    }
-  }
-  else {
-    for (const auto& op_name : op_names) {
-      ops.push_back(op_name);
-    }
-  }
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyList_New(ops.size());
-  for (int i = 0; i < ops.size(); ++i) {
-    PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-static PyObject* TF_GetOpProperties(GItem item) {
-  if (item.is_none()) {
-    Py_RETURN_NONE;
-  }
-  tensorflow::grappler::GraphProperties properties(*item);
-  tensorflow::Status status = properties.InferStatically(false);
-  if (!status.ok()) {
-    Py_RETURN_NONE;
-  }
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* props = PyDict_New();
-  for (const auto& node : item->graph.node()) {
-    const string& node_name = node.name();
-    const std::vector<tensorflow::OpInfo::TensorProperties>& output_props =
-        properties.GetOutputProperties(node_name);
-
-    PyObject* prop = PyList_New(output_props.size());
-    for (int i = 0; i < output_props.size(); ++i) {
-      string output_prop_str = output_props[i].SerializeAsString();
-      PyObject* output_prop = PyBytes_FromStringAndSize(output_prop_str.data(),
-                                                        output_prop_str.size());
-      PyList_SetItem(prop, i, output_prop);
-    }
-    CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
-  }
-  PyGILState_Release(gstate);
-  return props;
-}
-
-class ColocationGroups {
-public:
-  void Group(const string& x, const string& y) {
-    Rep* x_root = Find(x);
-    Rep* y_root = Find(y);
-
-    // x and y are already in the same set
-    if (x_root == y_root) {
-      return;
-    }
-    // x and y are not in same set, so we merge them
-    // Use the occasion to strengthen what we know about the handle by merging the
-    // information about the 2 subsets.
-    if (x_root->rank < y_root->rank) {
-      x_root->parent = y_root;
-    } else if (x_root->rank > y_root->rank) {
-      y_root->parent = x_root;
-    } else {
-      // Arbitrarily make one root the new parent
-      y_root->parent = x_root;
-      x_root->rank = x_root->rank + 1;
-    }
-  }
-
-  void ExtractGroups(std::vector<std::vector<string>>* groups) {
-    groups->reserve(nodes_.size());
-    std::unordered_map<const Rep*, int> group_ids;
-    for (const auto& rep : nodes_) {
-      Rep* r = Find(rep.first);
-      auto it = group_ids.find(r);
-      std::vector<string>* g;
-      if (it == group_ids.end()) {
-        int id = group_ids.size();
-        group_ids[r] = id;
-        groups->resize(id+1);
-        g = &groups->back();
-      } else {
-        int id = it->second;
-        g = &((*groups)[id]);
-      }
-      g->push_back(rep.first);
-    }
-  }
-
-private:
-  struct Rep {
-    // Parent in the tree used to encode the set.
-    Rep* parent;
-    // Rank in the tree, used to figure out how to compress the path to the root
-    // of the tree.
-    int rank;
-    // The node.
-    string value;
-  };
-
-  Rep* Find(const string& n) {
-    auto it = nodes_.find(n);
-    if (it == nodes_.end()) {
-      // This is the first time we process this handle, create an entry for it.
-      Rep* node = new Rep;
-      node->parent = node;
-      node->rank = 0;
-      node->value = n;
-      nodes_[n] = node;
-      return node;
-    }
-    // Return the representative for the set, which is the root of the tree. Apply
-    // path compression to speedup future queries.
-    Rep* node = it->second;
-    Rep* root = node->parent;
-    while (root != root->parent) {
-      root = root->parent;
-    }
-    while (node->parent != root) {
-      Rep* next = node->parent;
-      node->parent = root;
-      node = next;
-    }
-    return root;
-  }
-
-  std::unordered_map<string, Rep*> nodes_;
-};
-
-static PyObject* TF_GetColocationGroups(GItem item) {
-  if (item.is_none()) {
-    Py_RETURN_NONE;
-  }
-  ColocationGroups groupings;
-  tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
-  for (const auto& node : item->graph.node()) {
-    const tensorflow::OpDef* op_def;
-    tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
-    if (!s.ok()) {
-      continue;
-    }
-    tensorflow::NameRangeMap inputs;
-    tensorflow::NameRangeMap outputs;
-    s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
-    if (!s.ok()) {
-      continue;
-    }
-    for (const auto& arg : op_def->input_arg()) {
-      if (!arg.is_ref()) {
-        continue;
-      }
-      const auto& range = inputs[arg.name()];
-      for (int i = range.first; i < range.second; ++i) {
-        string input = tensorflow::grappler::NodeName(node.input(i));
-        groupings.Group(node.name(), input);
-      }
-    }
-  }
-
-  std::vector<std::vector<string>> groups;
-  groupings.ExtractGroups(&groups);
-
-  PyGILState_STATE gstate = PyGILState_Ensure();
-  PyObject* result = PyList_New(groups.size());
-  for (int i = 0; i < groups.size(); ++i) {
-    const std::vector<string>& group = groups[i];
-    PyObject* g = PyTuple_New(group.size());
-    for (int j = 0; j < group.size(); ++j) {
-      const string& node_name = group[j];
-      PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
-    }
-    PyList_SetItem(result, i, g);
-  }
-  PyGILState_Release(gstate);
-  return result;
-}
-
-%}
-
-
-// Wrap these functions.
-static GItem TF_NewItem(
-    const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
-    bool ignore_user_placement, TF_Status* status);
-static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
-                                         TF_Status* status);
-static PyObject* TF_GetOpProperties(GItem item);
-static PyObject* TF_GetColocationGroups(GItem item);
diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py
index 2578f017c26..c4495ac1e34 100644
--- a/tensorflow/python/grappler/item.py
+++ b/tensorflow/python/grappler/item.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 from tensorflow.core.grappler.costs import op_performance_data_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python import pywrap_tensorflow as tf_item
+from tensorflow.python import _pywrap_tf_item as tf_item
 
 
 class Item(object):
@@ -53,11 +53,14 @@ class Item(object):
     return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)
 
   def GetOpProperties(self):
-    ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item)
+    """Get Op properties."""
+    props = tf_item.TF_GetOpProperties(self.tf_item)
     properties = {}
-    for key, values in ret_from_swig.items():
+    for key, values in props.items():
       prop = []
       for value in values:
+        # TODO(petebu): Make this conversion to a dictionary be done in the C++
+        # wrapper for performance.
         prop.append(
             op_performance_data_pb2.OpInfo.TensorProperties.FromString(value))
       properties[key] = prop
diff --git a/tensorflow/python/grappler/item_wrapper.cc b/tensorflow/python/grappler/item_wrapper.cc
new file mode 100644
index 00000000000..d1c50f4e21a
--- /dev/null
+++ b/tensorflow/python/grappler/item_wrapper.cc
@@ -0,0 +1,245 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/stl.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+class ColocationGroups {
+ public:
+  void Group(const std::string& x, const std::string& y) {
+    Rep* x_root = Find(x);
+    Rep* y_root = Find(y);
+
+    // x and y are already in the same set
+    if (x_root == y_root) {
+      return;
+    }
+    // x and y are not in same set, so we merge them
+    // Use the occasion to strengthen what we know about the handle by merging
+    // the information about the 2 subsets.
+    if (x_root->rank < y_root->rank) {
+      x_root->parent = y_root;
+    } else if (x_root->rank > y_root->rank) {
+      y_root->parent = x_root;
+    } else {
+      // Arbitrarily make one root the new parent
+      y_root->parent = x_root;
+      x_root->rank = x_root->rank + 1;
+    }
+  }
+
+  void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
+    groups->reserve(nodes_.size());
+    std::unordered_map<const Rep*, int> group_ids;
+    for (const auto& rep : nodes_) {
+      Rep* r = Find(rep.first);
+      auto it = group_ids.find(r);
+      std::vector<std::string>* g;
+      if (it == group_ids.end()) {
+        int id = group_ids.size();
+        group_ids[r] = id;
+        groups->resize(id + 1);
+        g = &groups->back();
+      } else {
+        int id = it->second;
+        g = &((*groups)[id]);
+      }
+      g->push_back(rep.first);
+    }
+  }
+
+ private:
+  struct Rep {
+    // Parent in the tree used to encode the set.
+    Rep* parent;
+    // Rank in the tree, used to figure out how to compress the path to the root
+    // of the tree.
+    int rank;
+    // The node.
+    std::string value;
+  };
+
+  Rep* Find(const std::string& n) {
+    auto it = nodes_.find(n);
+    if (it == nodes_.end()) {
+      // This is the first time we process this handle, create an entry for it.
+      Rep* node = new Rep;
+      node->parent = node;
+      node->rank = 0;
+      node->value = n;
+      nodes_[n] = node;
+      return node;
+    }
+    // Return the representative for the set, which is the root of the tree.
+    // Apply path compression to speedup future queries.
+    Rep* node = it->second;
+    Rep* root = node->parent;
+    while (root != root->parent) {
+      root = root->parent;
+    }
+    while (node->parent != root) {
+      Rep* next = node->parent;
+      node->parent = root;
+      node = next;
+    }
+    return root;
+  }
+
+  std::unordered_map<std::string, Rep*> nodes_;
+};
+
+PYBIND11_MAKE_OPAQUE(tensorflow::grappler::GrapplerItem);
+
+PYBIND11_MODULE(_pywrap_tf_item, m) {
+  py::class_<tensorflow::grappler::GrapplerItem> grappler_item(
+      m, "tensorflow::grappler::GrapplerItem");
+
+  m.def("TF_NewItem",
+        [](const py::bytes& serialized_metagraph, bool ignore_colocation,
+           bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* {
+          tensorflow::MetaGraphDef metagraph;
+          if (!metagraph.ParseFromString(serialized_metagraph)) {
+            throw std::invalid_argument(
+                "The MetaGraphDef could not be parsed as a valid protocol "
+                "buffer");
+          }
+          if (metagraph.collection_def().count("train_op") == 0) {
+            MaybeRaiseRegisteredFromStatus(tensorflow::errors::InvalidArgument(
+                "train_op not specified in the metagraph"));
+          }
+
+          tensorflow::grappler::ItemConfig cfg;
+          cfg.ignore_user_placement = ignore_user_placement;
+          cfg.ignore_colocation = ignore_colocation;
+          std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
+              tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+                  "item", metagraph, cfg);
+          if (item == nullptr) {
+            MaybeRaiseRegisteredFromStatus(
+                tensorflow::errors::InvalidArgument("Invalid metagraph"));
+          }
+          return item.release();
+        });
+
+  m.def("TF_IdentifyImportantOps",
+        [](tensorflow::grappler::GrapplerItem* item,
+           bool sort_topologically) -> std::vector<std::string> {
+          std::vector<const tensorflow::NodeDef*> main_ops =
+              item->MainOpsFanin();
+          std::vector<const tensorflow::NodeDef*> enqueue_ops =
+              item->EnqueueOpsFanin();
+          std::unordered_set<std::string> op_names;
+          for (auto op : main_ops) {
+            op_names.insert(op->name());
+          }
+          for (auto op : enqueue_ops) {
+            op_names.insert(op->name());
+          }
+
+          std::vector<std::string> ops;
+          if (sort_topologically) {
+            tensorflow::GraphDef subgraph;
+            for (const tensorflow::NodeDef& node : item->graph.node()) {
+              if (op_names.find(node.name()) != op_names.end()) {
+                *subgraph.add_node() = node;
+              }
+            }
+            tensorflow::MaybeRaiseFromStatus(
+                tensorflow::grappler::TopologicalSort(&subgraph));
+            for (const tensorflow::NodeDef& node : subgraph.node()) {
+              ops.push_back(node.name());
+            }
+          } else {
+            for (const auto& op_name : op_names) {
+              ops.push_back(op_name);
+            }
+          }
+          return ops;
+        });
+
+  m.def("TF_GetOpProperties",
+        [](tensorflow::grappler::GrapplerItem* item)
+            -> std::unordered_map<std::string, std::vector<py::bytes>> {
+          tensorflow::grappler::GraphProperties properties(*item);
+          tensorflow::MaybeRaiseFromStatus(properties.InferStatically(false));
+
+          std::unordered_map<std::string, std::vector<py::bytes>> props;
+          for (const auto& node : item->graph.node()) {
+            const std::string& node_name = node.name();
+            const std::vector<tensorflow::OpInfo::TensorProperties>&
+                output_props = properties.GetOutputProperties(node_name);
+
+            std::vector<py::bytes> prop;
+            prop.reserve(output_props.size());
+            for (const auto& output_prop : output_props) {
+              prop.push_back(output_prop.SerializeAsString());
+            }
+            props[node_name] = prop;
+          }
+          return props;
+        });
+
+  m.def("TF_GetColocationGroups",
+        [](tensorflow::grappler::GrapplerItem* item)
+            -> std::vector<std::vector<std::string>> {
+          ColocationGroups groupings;
+          tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
+          for (const auto& node : item->graph.node()) {
+            const tensorflow::OpDef* op_def;
+            if (!registry->LookUpOpDef(node.op(), &op_def).ok()) {
+              continue;
+            }
+            tensorflow::NameRangeMap inputs;
+            tensorflow::NameRangeMap outputs;
+            if (!tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs)
+                     .ok()) {
+              continue;
+            }
+            for (const auto& arg : op_def->input_arg()) {
+              if (!arg.is_ref()) {
+                continue;
+              }
+              const auto& range = inputs[arg.name()];
+              for (int i = range.first; i < range.second; ++i) {
+                groupings.Group(node.name(),
+                                tensorflow::grappler::NodeName(node.input(i)));
+              }
+            }
+          }
+
+          std::vector<std::vector<std::string>> groups;
+          groupings.ExtractGroups(&groups);
+          return groups;
+        });
+}
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
deleted file mode 100644
index d8ba0409eb6..00000000000
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ /dev/null
@@ -1,144 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-
-%include "tensorflow/python/platform/base.i"
-%include "cluster.i"
-
-%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The MetaGraphDef could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%typemap(in) const tensorflow::ConfigProto& (
-    tensorflow::ConfigProto temp) {
-  char* c_string;
-  Py_ssize_t py_size;
-  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
-    // Python has raised an error (likely TypeError or UnicodeEncodeError).
-    SWIG_fail;
-  }
-
-  if (!temp.ParseFromString(string(c_string, py_size))) {
-    PyErr_SetString(
-        PyExc_TypeError,
-        "The ConfigProto could not be parsed as a valid protocol buffer");
-    SWIG_fail;
-  }
-  $1 = &temp;
-}
-
-%{
-  #include <memory>
-  #include "tensorflow/c/tf_status_helper.h"
-  #include "tensorflow/core/lib/core/status.h"
-  #include "tensorflow/core/common_runtime/device.h"
-  #include "tensorflow/core/framework/device_base.h"
-  #include "tensorflow/core/common_runtime/device_factory.h"
-  #include "tensorflow/core/framework/device_attributes.pb.h"
-  #include "tensorflow/core/framework/graph_def_util.h"
-  #include "tensorflow/core/framework/graph.pb.h"
-  #include "tensorflow/core/grappler/grappler_item.h"
-  #include "tensorflow/core/grappler/grappler_item_builder.h"
-  #include "tensorflow/core/grappler/clusters/cluster.h"
-  #include "tensorflow/core/grappler/clusters/utils.h"
-  #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-  #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
-  #include "tensorflow/core/protobuf/config.pb.h"
-  #include "tensorflow/core/protobuf/meta_graph.pb.h"
-  #include "tensorflow/core/public/session_options.h"
-
-
-void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
-  tensorflow::SessionOptions options;
-  std::vector<std::unique_ptr<tensorflow::Device>> devices;
-  tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
-  if (!status.ok()) {
-    return;
-  }
-
-  for (const std::unique_ptr<tensorflow::Device>& device : devices) {
-    tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
-    prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
-
-    // Overwrite the memory limit since users might have requested to use only a fraction of the
-    // available device memory.
-    const tensorflow::DeviceAttributes& attr = device->attributes();
-    prop.set_memory_size(attr.memory_limit());
-  }
-}
-
-PyObject* TF_OptimizeGraph(
-      GCluster cluster,
-      const tensorflow::ConfigProto& config_proto,
-      const tensorflow::MetaGraphDef& metagraph,
-      bool verbose, const string& graph_id,
-      bool strip_default_attributes,
-      TF_Status* status) {
-    tensorflow::grappler::ItemConfig item_config;
-    item_config.apply_optimizations = false;
-    item_config.ignore_user_placement = false;
-    std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
-        tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
-
-    if (!grappler_item) {
-      TF_SetStatus(status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
-      return nullptr;
-    }
-
-    tensorflow::DeviceBase* cpu_device = nullptr;
-    tensorflow::GraphDef out_graph;
-    tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
-    tensorflow::Status s = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
-    tensorflow::Set_TF_Status_from_Status(status, s);
-    if (!s.ok()) {
-       Py_RETURN_NONE;
-    }
-    if (strip_default_attributes) {
-      tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
-                                         out_graph.mutable_node());
-    }
-    if (verbose) {
-      optimizer.PrintResult();
-    }
-    string out_graph_str = out_graph.SerializeAsString();
-    PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
-                                              out_graph_str.size());
-    return ret;
-  }
-%}
-
-
-// Wrap this function
-PyObject* TF_OptimizeGraph(
-    GCluster cluster,
-    const tensorflow::ConfigProto& config_proto,
-    const tensorflow::MetaGraphDef& metagraph, bool verbose,
-    const string& graph_id, bool strip_default_attributes, TF_Status* status);
-
-
-
diff --git a/tensorflow/python/grappler/tf_optimizer.py b/tensorflow/python/grappler/tf_optimizer.py
index 5196a9d53b7..bea9acb573b 100644
--- a/tensorflow/python/grappler/tf_optimizer.py
+++ b/tensorflow/python/grappler/tf_optimizer.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python import pywrap_tensorflow as tf_opt
+from tensorflow.python import _pywrap_tf_optimizer as tf_opt
 from tensorflow.python.grappler import cluster as gcluster
 
 
@@ -52,12 +52,8 @@ def OptimizeGraph(config_proto,
                     type(config_proto))
   if cluster is None:
     cluster = gcluster.Cluster()
-  ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
-                                          config_proto.SerializeToString(),
-                                          metagraph.SerializeToString(),
-                                          verbose, graph_id,
-                                          strip_default_attributes)
-  if ret_from_swig is None:
-    return None
-  out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
-  return out_graph
+  out_graph = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
+                                      config_proto.SerializeToString(),
+                                      metagraph.SerializeToString(), verbose,
+                                      graph_id, strip_default_attributes)
+  return graph_pb2.GraphDef().FromString(out_graph)
diff --git a/tensorflow/python/grappler/tf_optimizer_wrapper.cc b/tensorflow/python/grappler/tf_optimizer_wrapper.cc
new file mode 100644
index 00000000000..91aeae473c0
--- /dev/null
+++ b/tensorflow/python/grappler/tf_optimizer_wrapper.cc
@@ -0,0 +1,108 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+void DetectDevices(
+    std::unordered_map<std::string, tensorflow::DeviceProperties>* device_map) {
+  tensorflow::SessionOptions options;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
+  if (!tensorflow::DeviceFactory::AddDevices(options, "", &devices).ok()) {
+    return;
+  }
+
+  for (const std::unique_ptr<tensorflow::Device>& device : devices) {
+    tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
+    prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
+
+    // Overwrite the memory limit since users might have requested to use only a
+    // fraction of the available device memory.
+    const tensorflow::DeviceAttributes& attr = device->attributes();
+    prop.set_memory_size(attr.memory_limit());
+  }
+}
+
+PYBIND11_MODULE(_pywrap_tf_optimizer, m) {
+  m.def(
+      "TF_OptimizeGraph",
+      [](tensorflow::grappler::Cluster* cluster,
+         const py::bytes& serialized_config_proto,
+         const py::bytes& serialized_metagraph, bool verbose,
+         const std::string& graph_id,
+         bool strip_default_attributes) -> py::bytes {
+        tensorflow::ConfigProto config_proto;
+        if (!config_proto.ParseFromString(serialized_config_proto)) {
+          throw std::invalid_argument(
+              "The ConfigProto could not be parsed as a valid protocol buffer");
+        }
+        tensorflow::MetaGraphDef metagraph;
+        if (!metagraph.ParseFromString(serialized_metagraph)) {
+          throw std::invalid_argument(
+              "The MetaGraphDef could not be parsed as a valid protocol "
+              "buffer");
+        }
+
+        tensorflow::grappler::ItemConfig item_config;
+        // This disables graph optimizations in the older graph optimizer, which
+        // tend to overlap / be redundant with those in Grappler.
+        item_config.apply_optimizations = false;
+        item_config.ignore_user_placement = false;
+        std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
+            tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+                graph_id, metagraph, item_config);
+        if (!grappler_item) {
+          throw std::invalid_argument(
+              "Failed to import metagraph, check error log for more info.");
+        }
+
+        tensorflow::DeviceBase* cpu_device = nullptr;
+        tensorflow::GraphDef out_graph;
+        tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
+
+        MaybeRaiseRegisteredFromStatus(
+            optimizer.Optimize(cluster, *grappler_item, &out_graph));
+        if (strip_default_attributes) {
+          tensorflow::StripDefaultAttributes(*tensorflow::OpRegistry::Global(),
+                                             out_graph.mutable_node());
+        }
+        if (verbose) {
+          optimizer.PrintResult();
+        }
+        return out_graph.SerializeAsString();
+      });
+}
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 4f6bb3f9efd..12beb907982 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -17,10 +17,11 @@ limitations under the License.
  * The includes are intentionally not alphabetically sorted, as the order of
  * includes follows dependency order */
 
-%include "tensorflow/python/grappler/cluster.i"
-%include "tensorflow/python/grappler/item.i"
-%include "tensorflow/python/grappler/tf_optimizer.i"
-%include "tensorflow/python/grappler/cost_analyzer.i"
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/core/lib/core/status.h"
+%}
 
 // TODO(slebedev): This is a temporary workaround for projects implicitly
 // relying on TensorFlow exposing tensorflow::Status.
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 2fccb319d28..c863be99c8a 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -243,3 +243,59 @@ tensorflow::TF_GraphSetTensorShape_wrapper
 tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper
 tensorflow::TF_TryEvaluateConstant_wrapper
 
+[grappler_item] # tf_item
+tensorflow::grappler::GrapplerItem::MainOpsFanin
+tensorflow::grappler::GrapplerItem::EnqueueOpsFanin
+
+[graph_properties] # tf_item
+tensorflow::grappler::GraphProperties::InferStatically
+tensorflow::grappler::GraphProperties::GetOutputProperties
+
+[grappler_item_builder] # tf_item
+tensorflow::grappler::GrapplerItemFromMetaGraphDef
+
+[topological_sort] # tf_item
+tensorflow::grappler::TopologicalSort
+
+[clusters/utils] # tf_cluster tf_optimizer
+tensorflow::grappler::GetDeviceInfo
+
+[costs/utils] # tf_optimizer tf_cluster
+tensorflow::grappler::CostGraphToOpPerformanceData
+tensorflow::grappler::GetDeviceInfo
+
+[meta_optimizer] # tf_optimizer
+tensorflow::grappler::MetaOptimizer::MetaOptimizer
+tensorflow::grappler::MetaOptimizer::Optimize
+tensorflow::grappler::MetaOptimizer::PrintResult
+
+[clusters/cluster] # tf_cluster
+tensorflow::grappler::Cluster::AllowSoftPlacement
+tensorflow::grappler::Cluster::SetNumWarmupSteps
+tensorflow::grappler::Cluster::DisableDetailedStats
+tensorflow::grappler::Cluster::DetailedStatsEnabled
+
+[single_machine] # tf_cluster
+tensorflow::grappler::SingleMachine::SingleMachine
+
+[op_level_cost_estimator] # tf_cluster
+tensorflow::grappler::OpLevelCostEstimator::OpLevelCostEstimator
+tensorflow::grappler::OpLevelCostEstimator::PredictCosts
+tensorflow::grappler::OpLevelCostEstimator::GetDeviceInfo
+
+[virtual_cluster] # tf_cluster
+tensorflow::grappler::VirtualCluster::VirtualCluster
+
+[graph_memory] # tf_cluster
+tensorflow::grappler::GraphMemory::InferStatically
+tensorflow::grappler::GraphMemory::InferDynamically
+
+[measuring_cost_estimator] # tf_cluster
+tensorflow::grappler::MeasuringCostEstimator::MeasuringCostEstimator
+tensorflow::grappler::MeasuringCostEstimator::Initialize
+tensorflow::grappler::MeasuringCostEstimator::PredictCosts
+
+[devices] # tf_cluster
+tensorflow::grappler::GetNumAvailableGPUs
+tensorflow::grappler::GetNumAvailableLogicalCPUCores
+