From 3ac0818dea42c6699cda3cea26657f606f11a3cc Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Mon, 26 Oct 2020 17:42:57 -0700
Subject: [PATCH] Add PythonTensorConverter class, which can be used in c++ to
 efficiently convert PyObjects to tensors.

PiperOrigin-RevId: 339155091
Change-Id: Icad20253f523e3ac3685d8d34c52e940f3889b57
---
 tensorflow/c/eager/BUILD                      |   1 +
 tensorflow/python/BUILD                       |  75 +++++++
 tensorflow/python/eager/BUILD                 |   1 +
 .../framework/python_tensor_converter.cc      | 136 ++++++++++++
 .../framework/python_tensor_converter.h       |  76 +++++++
 .../framework/python_tensor_converter_test.py | 208 ++++++++++++++++++
 .../python_tensor_converter_wrapper.cc        | 120 ++++++++++
 .../tools/def_file_filter/symbols_pybind.txt  |   3 +
 8 files changed, 620 insertions(+)
 create mode 100644 tensorflow/python/framework/python_tensor_converter.cc
 create mode 100644 tensorflow/python/framework/python_tensor_converter.h
 create mode 100644 tensorflow/python/framework/python_tensor_converter_test.py
 create mode 100644 tensorflow/python/framework/python_tensor_converter_wrapper.cc

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c44d0ee6873..fa0fdbae861 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -110,6 +110,7 @@ filegroup(
         "abstract_function.h",
         "abstract_operation.h",
         "abstract_tensor_handle.h",
+        "c_api.h",
         "c_api_experimental.h",
         "c_api_internal.h",
         "c_api_unified_experimental.h",
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 89c3719b943..903c2449715 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1777,6 +1777,79 @@ tf_py_test(
     ],
 )
 
+cc_library(
+    name = "python_tensor_converter",
+    srcs = ["framework/python_tensor_converter.cc"],
+    hdrs = ["framework/python_tensor_converter.h"],
+    deps = [
+        ":cpp_python_util",
+        ":safe_pyobject_ptr",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/python/eager:pywrap_tfe_lib",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+# Note: this target is only used by python_tensor_converter_test.
+tf_python_pybind_extension(
+    name = "_pywrap_python_tensor_converter",
+    srcs = ["framework/python_tensor_converter_wrapper.cc"],
+    hdrs = [
+        "framework/python_tensor_converter.h",
+        "lib/core/numpy.h",
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:pywrap_required_hdrs",
+        "//tensorflow/c/experimental/ops:pywrap_required_hdrs",
+        "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
+        "//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
+        "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
+        "//tensorflow/python/eager:pywrap_required_hdrs",
+    ],
+    module_name = "_pywrap_python_tensor_converter",
+    deps = [
+        ":safe_pyobject_ptr_required_hdrs",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@pybind11",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/common_runtime:core_cpu_headers_lib",
+        "//tensorflow/core:lib_headers_for_pybind",
+        "//third_party/py/numpy:headers",
+        "//tensorflow/c:pywrap_required_hdrs",
+        "@com_google_absl//absl/types:span",
+    ] + if_static(
+        extra_deps = [
+            "//tensorflow/core/protobuf:eager_service_proto_cc",
+            "//tensorflow/core/protobuf:master_proto_cc",
+            "//tensorflow/core/protobuf:worker_proto_cc",
+        ],
+        otherwise = [
+            "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:master_proto_cc_headers_only",
+            "//tensorflow/core/protobuf:worker_proto_cc_headers_only",
+        ],
+    ),
+)
+
+tf_py_test(
+    name = "python_tensor_converter_test",
+    srcs = ["framework/python_tensor_converter_test.py"],
+    python_version = "PY3",
+    tags = ["no_pip"],
+    deps = [
+        ":_pywrap_python_tensor_converter",
+        ":client_testlib",
+    ],
+)
+
 py_library(
     name = "framework_ops",  # "ops" is already the name of a deprecated target
     srcs = ["framework/ops.py"],
@@ -6026,6 +6099,7 @@ pywrap_tensorflow_macro(
         ":pybind11_proto",
         ":python_api_dispatcher",
         ":python_op_gen",
+        ":python_tensor_converter",
         ":safe_pyobject_ptr",
         ":tf_session_helper",
         "//third_party/python_runtime:headers",
@@ -6096,6 +6170,7 @@ filegroup(
         ":py_exception_registry",  # py_exception_registry
         ":py_func_lib",  # py_func
         ":python_api_dispatcher",  # python_api_dispatcher
+        ":python_tensor_converter",  # python_tensor_converter
         ":python_op_gen",  # python_op_gen
         ":safe_ptr",  # checkpoint_reader
         "//tensorflow/c:checkpoint_reader",  # checkpoint_reader
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index fed430a2f8e..789e9419d9e 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -72,6 +72,7 @@ cc_library(
         "//tensorflow/python:py_seq_tensor",
         "//tensorflow/python:py_util",
         "//tensorflow/python:safe_ptr",
+        "//tensorflow/python:safe_pyobject_ptr",
         "//tensorflow/python:stack_trace",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
diff --git a/tensorflow/python/framework/python_tensor_converter.cc b/tensorflow/python/framework/python_tensor_converter.cc
new file mode 100644
index 00000000000..f18c8a8c681
--- /dev/null
+++ b/tensorflow/python/framework/python_tensor_converter.cc
@@ -0,0 +1,136 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/framework/python_tensor_converter.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
+#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "tensorflow/python/util/util.h"
+
+#if PY_MAJOR_VERSION < 3
+// Python 2.x:
+#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
+#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
+#else
+// Python 3.x:
+#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
+#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
+#endif
+
+namespace tensorflow {
+namespace {
+
+// Returns `tensor.dtype._type_enum` as a DataType enum.  Assumes that `tensor`
+// is a python `Tensor` object.
+//
+// On error: sets a python AttributeError exception and returns DT_INVALID.
+DataType DataTypeForTensor(PyObject* tensor) {
+  static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype");
+  static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
+
+  Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr));
+  if (!py_dtype) return DT_INVALID;
+
+  Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr));
+  if (!enum_field) return DT_INVALID;
+
+  DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
+  return result;
+}
+
+// Check that actual_dtype == expected_dtype.  If not, set an exception and
+// return false.  (If expected_dtype is DT_INVALID, then instead simply update
+// its value to `actual_dtype` and return true.)
+bool CheckDType(DataType actual_dtype, DataType& expected_dtype) {
+  if (expected_dtype == DT_INVALID) {
+    expected_dtype = actual_dtype;  // set output parameter.
+  } else if (expected_dtype != actual_dtype) {
+    PyErr_SetString(PyExc_TypeError,
+                    absl::StrCat("Expected ", DataType_Name(expected_dtype),
+                                 " but got ", DataType_Name(actual_dtype))
+                        .c_str());
+    return false;
+  }
+  return true;
+}
+
+}  // namespace
+
+Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype,
+                                                bool* used_fallback) const {
+  // First, try converting `src` to a Tensor without calling back into Python.
+  if (ctx_) {  // Eager mode
+    // TODO(b/164980194): Handle resource variables as well.  (See
+    // ConvertToTensor function in pywrap_tfe_src.cc).
+    if (EagerTensor_CheckExact(src)) {
+      // `src` is already an eager tensor; check its type, and return it as-is.
+      if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr;
+      Py_INCREF(src);
+      return Safe_PyObjectPtr(src);
+    } else {
+      TFE_TensorHandle* handle =
+          tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_);
+      if (handle) {
+        Safe_PyObjectPtr result(EagerTensorFromHandle(handle));
+        if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) {
+          return nullptr;
+        }
+        return result;
+      } else {
+        PyErr_Clear();
+      }
+    }
+  } else {  // Graph mode
+    if (swig::IsTensor(src)) {
+      DataType src_dtype = DataTypeForTensor(src);
+      if (src_dtype == DT_INVALID) return nullptr;
+      if (!CheckDType(src_dtype, dtype)) return nullptr;
+      Py_INCREF(src);
+      return Safe_PyObjectPtr(src);
+    }
+  }
+
+  // Fallback: use the Python tf.convert_to_tensor function.
+  // Currently this is used:
+  //
+  // * In Eager mode: for anything that's not already an Eager tensor, or
+  //   handled by `tensorflow::ConvertToEagerTensor`.  (At time of writing
+  //   for this comment, ConvertToEagerTensor handles simple values like ints,
+  //   nested lists of simple values, and numpy arrays.)
+  // * In graph mode: for anything that's not already a tensor.
+  //
+  // TODO(b/164980194) Reduce/eliminate cases where fallback is used.
+  if (used_fallback) *used_fallback = true;
+  static PyObject* convert_to_tensor =
+      swig::GetRegisteredPyObject("tf.convert_to_tensor");
+  if (!convert_to_tensor) return nullptr;
+
+  Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2));
+  Safe_PyObjectPtr kwargs(PyDict_New());
+  Py_INCREF(src);
+  PyTuple_SetItem(args.get(), 0, src);
+  if (dtype != DT_INVALID) {
+    PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype));
+  }
+  PyDict_SetItemString(kwargs.get(), "ctx", py_eager_context_);
+  Safe_PyObjectPtr result(
+      PyObject_Call(convert_to_tensor, args.get(), kwargs.get()));
+  if (!result) return nullptr;
+  dtype = DataTypeForTensor(result.get());  // set output parameter.
+  if (dtype == DT_INVALID) return nullptr;
+  return result;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/framework/python_tensor_converter.h b/tensorflow/python/framework/python_tensor_converter.h
new file mode 100644
index 00000000000..faf1793d4cd
--- /dev/null
+++ b/tensorflow/python/framework/python_tensor_converter.h
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
+#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
+
+#include <Python.h>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+
+namespace tensorflow {
+
+// Converts PyObject* values to Tensors.
+//
+// This converter attempts to convert values as efficiently as possible; but
+// it has fallback paths to handle any PyObject* value for which tensor
+// conversion is defined.
+class PythonTensorConverter {
+ public:
+  // Constructs a new PythonTensorConverter.
+  //
+  // Note: the arguments to this constructor may change in the future, as
+  // we move more of python tensor conversion from the Python layer to the
+  // c++ layer.
+  //
+  // Args:
+  //   py_eager_context: the value of context.context() from eager/context.py.
+  //   ctx: The c++ eager context, or nullptr in graph mode.
+  //   device_name: The current device name.
+  //
+  // All three argument values must remain alive until `this` is deleted.
+  PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx,
+                        const char* device_name)
+      : py_eager_context_(py_eager_context),
+        ctx_(ctx),
+        device_name_(device_name) {}
+
+  // Converts `src` to a tensor (if it's not already one), and returns a new
+  // reference to the converted value.
+  //
+  // Args:
+  //   src: The object that should be converted to a Tensor.
+  //   dtype: The requested dtype.  Use `DT_INVALID` if the dtype should be
+  //     inferred from the `src` value (in which case `dtype` will be updated
+  //     in-place to be the actual dtype of the converted value).
+  //   used_fallback: Output parameter used to record whether the conversion
+  //     was done by falling back to the Python `tf.convert_to_tensor()`
+  //     function.  This is for testing/logging purposes only.  May be null.
+  //
+  // If `src` can't be converted to a tensor with the requested dtype, sets a
+  // Python exception and returns nullptr.
+  Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype,
+                           bool* used_fallback = nullptr) const;
+
+ private:
+  PyObject* py_eager_context_;
+  TFE_Context* ctx_;
+  const char* device_name_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
diff --git a/tensorflow/python/framework/python_tensor_converter_test.py b/tensorflow/python/framework/python_tensor_converter_test.py
new file mode 100644
index 00000000000..a29f87f3e23
--- /dev/null
+++ b/tensorflow/python/framework/python_tensor_converter_test.py
@@ -0,0 +1,208 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.python_tensor_converter."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import numpy as np
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python import _pywrap_python_tensor_converter
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import indexed_slices
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class PythonTensorConverterTest(test_util.TensorFlowTestCase,
+                                parameterized.TestCase):
+
+  def setUp(self):
+    context.ensure_initialized()
+    super(PythonTensorConverterTest, self).setUp()
+
+  def makePythonTensorConverter(self):
+    return _pywrap_python_tensor_converter.PythonTensorConverter(
+        context.context())
+
+  #=============================================================================
+  # Convert int to tensor.
+
+  def testConvertIntWithInferredDType(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, 12)
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertIntWithExplicitDtype(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, 12)
+    self.assertEqual(dtype, types_pb2.DT_INT64)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertIntWithIncompatibleDtype(self):
+    converter = self.makePythonTensorConverter()
+    with self.assertRaisesRegex(
+        TypeError, "Expected string, got 3 of type 'int' instead."
+        "|Cannot convert 3 to EagerTensor of dtype string"):
+      converter.Convert(3, types_pb2.DT_STRING)
+
+  #=============================================================================
+  # Convert tensor to tensor.
+
+  def testConvertTensorWithInferredDType(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert(
+        constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [1, 2, 3])
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertFalse(used_fallback)
+
+  def testConvertTensorWithExplicitDtype(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert(
+        constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [1, 2, 3])
+    self.assertEqual(dtype, types_pb2.DT_INT64)
+    self.assertFalse(used_fallback)
+
+  def testConvertTensorWithIncorrectDtype(self):
+    converter = self.makePythonTensorConverter()
+    with self.assertRaises((TypeError, ValueError)):
+      converter.Convert(
+          constant_op.constant([1, 2, 3], dtypes.int32), types_pb2.DT_INT64)
+
+  #=============================================================================
+  # Convert list to tensor.
+
+  def testConvertListWithInferredDType(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
+                                                     types_pb2.DT_INVALID)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertListWithExplicitDtype(self):
+    converter = self.makePythonTensorConverter()
+    result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
+                                                     types_pb2.DT_INT64)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
+    self.assertEqual(dtype, types_pb2.DT_INT64)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertListWithIncompatibleDtype(self):
+    converter = self.makePythonTensorConverter()
+    with self.assertRaisesRegex(
+        TypeError, "Expected string, got .* of type 'int' instead."
+        "|Cannot convert .* to EagerTensor of dtype string"):
+      converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_STRING)
+
+  def testConvertListWithInconsistentDtype(self):
+    converter = self.makePythonTensorConverter()
+    with self.assertRaisesRegex(
+        (TypeError, ValueError),
+        "Can't convert Python sequence with mixed types to Tensor."
+        "|Failed to convert"):
+      converter.Convert([[1, 2], ["a", "b"]], types_pb2.DT_INVALID)
+
+  #=============================================================================
+  # Convert np.array to tensor.
+
+  def testConvertNumpyArrayWithInferredDType(self):
+    converter = self.makePythonTensorConverter()
+    x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
+    result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertNumpyArrayWithExplicitDtype(self):
+    converter = self.makePythonTensorConverter()
+    x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
+    result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
+    self.assertEqual(dtype, types_pb2.DT_INT64)
+    self.assertEqual(used_fallback, not context.executing_eagerly())
+
+  def testConvertNumpyArrayWithIncompatibleDtype(self):
+    converter = self.makePythonTensorConverter()
+    x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
+    with self.assertRaises((ValueError, TypeError)):
+      converter.Convert(x, types_pb2.DT_STRING)
+
+  def testConvertNumpyArrayWithUnsupportedDtype(self):
+    converter = self.makePythonTensorConverter()
+    x = np.array([[1, 2], ["a", "b"]], np.object)
+    with self.assertRaises((ValueError, TypeError)):
+      converter.Convert(x, types_pb2.DT_INVALID)
+
+  #=============================================================================
+  # Convert IndexedSlices to tensor.
+
+  def testConvertIndexedSlicesWithInferredDType(self):
+    converter = self.makePythonTensorConverter()
+    x = indexed_slices.IndexedSlices(
+        constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
+        constant_op.constant([1], dtypes.int64, name="x_indices"),
+        constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
+    result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertTrue(used_fallback)
+
+  def testConvertIndexedSlicesWithExplicitDtype(self):
+    converter = self.makePythonTensorConverter()
+    x = indexed_slices.IndexedSlices(
+        constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
+        constant_op.constant([1], dtypes.int64, name="x_indices"),
+        constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
+    result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32)
+    self.assertIsInstance(result, ops.Tensor)
+    self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
+    self.assertEqual(dtype, types_pb2.DT_INT32)
+    self.assertTrue(used_fallback)
+
+  def testConvertIndexedSlicesWithIncorrectDtype(self):
+    converter = self.makePythonTensorConverter()
+    x = indexed_slices.IndexedSlices(
+        constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
+        constant_op.constant([1], dtypes.int64, name="x_indices"),
+        constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
+    with self.assertRaises((ValueError, TypeError)):
+      converter.Convert(x, types_pb2.DT_FLOAT)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/framework/python_tensor_converter_wrapper.cc b/tensorflow/python/framework/python_tensor_converter_wrapper.cc
new file mode 100644
index 00000000000..33491869dc6
--- /dev/null
+++ b/tensorflow/python/framework/python_tensor_converter_wrapper.cc
@@ -0,0 +1,120 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Note: This library is only used by python_tensor_converter_test.  It is
+// not meant to be used in other circumstances.
+
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "tensorflow/python/framework/python_tensor_converter.h"
+
+#if PY_MAJOR_VERSION < 3
+// Python 2.x:
+#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
+#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
+#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
+#else
+// Python 3.x:
+#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
+#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
+#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
+#endif
+
+namespace py = pybind11;
+
+namespace tensorflow {
+namespace {
+
+Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) {
+  static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data");
+  return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
+}
+
+Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) {
+  static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle");
+  return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
+}
+
+Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) {
+  static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager");
+  return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
+}
+
+Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) {
+  static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name");
+  return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
+}
+
+Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
+  static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
+  return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
+}
+
+PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) {
+  Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr());
+  if (!tld) throw py::error_already_set();
+
+  Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get());
+  if (!py_is_eager) throw py::error_already_set();
+  bool is_eager = PyObject_IsTrue(py_is_eager.get());
+
+  // Initialize the eager context, if necessary.
+  TFE_Context* ctx = nullptr;
+  const char* device_name = nullptr;
+  if (is_eager) {
+    Safe_PyObjectPtr context_handle =
+        GetAttr_ContextHandle(py_eager_context.ptr());
+    if (!context_handle) throw py::error_already_set();
+    if (context_handle.get() == Py_None) {
+      throw std::runtime_error("Error retrieving context handle.");
+    }
+    Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get());
+    if (!py_device_name) {
+      throw std::runtime_error("Error retrieving device name.");
+    }
+    device_name = TFE_GetPythonString(py_device_name.get());
+    ctx = reinterpret_cast<TFE_Context*>(
+        PyCapsule_GetPointer(context_handle.get(), nullptr));
+  }
+
+  return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name);
+}
+
+py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj,
+                   py::handle dtype) {
+  DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr()));
+  bool used_fallback = false;
+  Safe_PyObjectPtr converted =
+      self->Convert(obj.ptr(), dtype_enum, &used_fallback);
+  if (!converted) throw py::error_already_set();
+
+  PyObject* result = PyTuple_New(3);
+  PyTuple_SET_ITEM(result, 0, converted.release());
+  PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum));
+  PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False);
+  Py_INCREF(PyTuple_GET_ITEM(result, 1));
+  Py_INCREF(PyTuple_GET_ITEM(result, 2));
+  return result;
+}
+
+}  // namespace
+}  // namespace tensorflow
+
+PYBIND11_MODULE(_pywrap_python_tensor_converter, m) {
+  py::class_<tensorflow::PythonTensorConverter>(m, "PythonTensorConverter")
+      .def(py::init(&tensorflow::MakePythonTensorConverter))
+      .def("Convert", tensorflow::Convert);
+}
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 222a6bb262b..af5b1a104f4 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -400,3 +400,6 @@ tensorflow::TensorHandle::Tensor
 
 [python_api_dispatcher] # python_api_dispatcher
 tensorflow::PythonAPIDispatcher
+
+[python_tensor_converter] # python_tensor_converter
+tensorflow::PythonTensorConverter