Add PythonTensorConverter class, which can be used in c++ to efficiently convert PyObjects to tensors.

PiperOrigin-RevId: 339155091
Change-Id: Icad20253f523e3ac3685d8d34c52e940f3889b57
This commit is contained in:
Edward Loper 2020-10-26 17:42:57 -07:00 committed by TensorFlower Gardener
parent 08e92a07fe
commit 3ac0818dea
8 changed files with 620 additions and 0 deletions

View File

@ -110,6 +110,7 @@ filegroup(
"abstract_function.h",
"abstract_operation.h",
"abstract_tensor_handle.h",
"c_api.h",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",

View File

@ -1777,6 +1777,79 @@ tf_py_test(
],
)
cc_library(
name = "python_tensor_converter",
srcs = ["framework/python_tensor_converter.cc"],
hdrs = ["framework/python_tensor_converter.h"],
deps = [
":cpp_python_util",
":safe_pyobject_ptr",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:protos_all_cc",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/strings",
],
)
# Note: this target is only used by python_tensor_converter_test.
tf_python_pybind_extension(
name = "_pywrap_python_tensor_converter",
srcs = ["framework/python_tensor_converter_wrapper.cc"],
hdrs = [
"framework/python_tensor_converter.h",
"lib/core/numpy.h",
"//tensorflow/c:headers",
"//tensorflow/c/eager:pywrap_required_hdrs",
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
"//tensorflow/python/eager:pywrap_required_hdrs",
],
module_name = "_pywrap_python_tensor_converter",
deps = [
":safe_pyobject_ptr_required_hdrs",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@pybind11",
"//third_party/python_runtime:headers", # buildcleaner: keep
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:framework",
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
"//tensorflow/core:lib_headers_for_pybind",
"//third_party/py/numpy:headers",
"//tensorflow/c:pywrap_required_hdrs",
"@com_google_absl//absl/types:span",
] + if_static(
extra_deps = [
"//tensorflow/core/protobuf:eager_service_proto_cc",
"//tensorflow/core/protobuf:master_proto_cc",
"//tensorflow/core/protobuf:worker_proto_cc",
],
otherwise = [
"//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
"//tensorflow/core/protobuf:master_proto_cc_headers_only",
"//tensorflow/core/protobuf:worker_proto_cc_headers_only",
],
),
)
tf_py_test(
name = "python_tensor_converter_test",
srcs = ["framework/python_tensor_converter_test.py"],
python_version = "PY3",
tags = ["no_pip"],
deps = [
":_pywrap_python_tensor_converter",
":client_testlib",
],
)
py_library(
name = "framework_ops", # "ops" is already the name of a deprecated target
srcs = ["framework/ops.py"],
@ -6026,6 +6099,7 @@ pywrap_tensorflow_macro(
":pybind11_proto",
":python_api_dispatcher",
":python_op_gen",
":python_tensor_converter",
":safe_pyobject_ptr",
":tf_session_helper",
"//third_party/python_runtime:headers",
@ -6096,6 +6170,7 @@ filegroup(
":py_exception_registry", # py_exception_registry
":py_func_lib", # py_func
":python_api_dispatcher", # python_api_dispatcher
":python_tensor_converter", # python_tensor_converter
":python_op_gen", # python_op_gen
":safe_ptr", # checkpoint_reader
"//tensorflow/c:checkpoint_reader", # checkpoint_reader

View File

@ -72,6 +72,7 @@ cc_library(
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:py_util",
"//tensorflow/python:safe_ptr",
"//tensorflow/python:safe_pyobject_ptr",
"//tensorflow/python:stack_trace",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_tensor_converter.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/util/util.h"
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#else
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#endif
namespace tensorflow {
namespace {
// Returns `tensor.dtype._type_enum` as a DataType enum. Assumes that `tensor`
// is a python `Tensor` object.
//
// On error: sets a python AttributeError exception and returns DT_INVALID.
DataType DataTypeForTensor(PyObject* tensor) {
static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype");
static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr));
if (!py_dtype) return DT_INVALID;
Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr));
if (!enum_field) return DT_INVALID;
DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
return result;
}
// Check that actual_dtype == expected_dtype. If not, set an exception and
// return false. (If expected_dtype is DT_INVALID, then instead simply update
// its value to `actual_dtype` and return true.)
bool CheckDType(DataType actual_dtype, DataType& expected_dtype) {
if (expected_dtype == DT_INVALID) {
expected_dtype = actual_dtype; // set output parameter.
} else if (expected_dtype != actual_dtype) {
PyErr_SetString(PyExc_TypeError,
absl::StrCat("Expected ", DataType_Name(expected_dtype),
" but got ", DataType_Name(actual_dtype))
.c_str());
return false;
}
return true;
}
} // namespace
Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype,
bool* used_fallback) const {
// First, try converting `src` to a Tensor without calling back into Python.
if (ctx_) { // Eager mode
// TODO(b/164980194): Handle resource variables as well. (See
// ConvertToTensor function in pywrap_tfe_src.cc).
if (EagerTensor_CheckExact(src)) {
// `src` is already an eager tensor; check its type, and return it as-is.
if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr;
Py_INCREF(src);
return Safe_PyObjectPtr(src);
} else {
TFE_TensorHandle* handle =
tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_);
if (handle) {
Safe_PyObjectPtr result(EagerTensorFromHandle(handle));
if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) {
return nullptr;
}
return result;
} else {
PyErr_Clear();
}
}
} else { // Graph mode
if (swig::IsTensor(src)) {
DataType src_dtype = DataTypeForTensor(src);
if (src_dtype == DT_INVALID) return nullptr;
if (!CheckDType(src_dtype, dtype)) return nullptr;
Py_INCREF(src);
return Safe_PyObjectPtr(src);
}
}
// Fallback: use the Python tf.convert_to_tensor function.
// Currently this is used:
//
// * In Eager mode: for anything that's not already an Eager tensor, or
// handled by `tensorflow::ConvertToEagerTensor`. (At time of writing
// for this comment, ConvertToEagerTensor handles simple values like ints,
// nested lists of simple values, and numpy arrays.)
// * In graph mode: for anything that's not already a tensor.
//
// TODO(b/164980194) Reduce/eliminate cases where fallback is used.
if (used_fallback) *used_fallback = true;
static PyObject* convert_to_tensor =
swig::GetRegisteredPyObject("tf.convert_to_tensor");
if (!convert_to_tensor) return nullptr;
Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2));
Safe_PyObjectPtr kwargs(PyDict_New());
Py_INCREF(src);
PyTuple_SetItem(args.get(), 0, src);
if (dtype != DT_INVALID) {
PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype));
}
PyDict_SetItemString(kwargs.get(), "ctx", py_eager_context_);
Safe_PyObjectPtr result(
PyObject_Call(convert_to_tensor, args.get(), kwargs.get()));
if (!result) return nullptr;
dtype = DataTypeForTensor(result.get()); // set output parameter.
if (dtype == DT_INVALID) return nullptr;
return result;
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
#include <Python.h>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
namespace tensorflow {
// Converts PyObject* values to Tensors.
//
// This converter attempts to convert values as efficiently as possible; but
// it has fallback paths to handle any PyObject* value for which tensor
// conversion is defined.
class PythonTensorConverter {
public:
// Constructs a new PythonTensorConverter.
//
// Note: the arguments to this constructor may change in the future, as
// we move more of python tensor conversion from the Python layer to the
// c++ layer.
//
// Args:
// py_eager_context: the value of context.context() from eager/context.py.
// ctx: The c++ eager context, or nullptr in graph mode.
// device_name: The current device name.
//
// All three argument values must remain alive until `this` is deleted.
PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx,
const char* device_name)
: py_eager_context_(py_eager_context),
ctx_(ctx),
device_name_(device_name) {}
// Converts `src` to a tensor (if it's not already one), and returns a new
// reference to the converted value.
//
// Args:
// src: The object that should be converted to a Tensor.
// dtype: The requested dtype. Use `DT_INVALID` if the dtype should be
// inferred from the `src` value (in which case `dtype` will be updated
// in-place to be the actual dtype of the converted value).
// used_fallback: Output parameter used to record whether the conversion
// was done by falling back to the Python `tf.convert_to_tensor()`
// function. This is for testing/logging purposes only. May be null.
//
// If `src` can't be converted to a tensor with the requested dtype, sets a
// Python exception and returns nullptr.
Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype,
bool* used_fallback = nullptr) const;
private:
PyObject* py_eager_context_;
TFE_Context* ctx_;
const char* device_name_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_

View File

@ -0,0 +1,208 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.python_tensor_converter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import types_pb2
from tensorflow.python import _pywrap_python_tensor_converter
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class PythonTensorConverterTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def setUp(self):
context.ensure_initialized()
super(PythonTensorConverterTest, self).setUp()
def makePythonTensorConverter(self):
return _pywrap_python_tensor_converter.PythonTensorConverter(
context.context())
#=============================================================================
# Convert int to tensor.
def testConvertIntWithInferredDType(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, 12)
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertIntWithExplicitDtype(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, 12)
self.assertEqual(dtype, types_pb2.DT_INT64)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertIntWithIncompatibleDtype(self):
converter = self.makePythonTensorConverter()
with self.assertRaisesRegex(
TypeError, "Expected string, got 3 of type 'int' instead."
"|Cannot convert 3 to EagerTensor of dtype string"):
converter.Convert(3, types_pb2.DT_STRING)
#=============================================================================
# Convert tensor to tensor.
def testConvertTensorWithInferredDType(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert(
constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [1, 2, 3])
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertFalse(used_fallback)
def testConvertTensorWithExplicitDtype(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert(
constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [1, 2, 3])
self.assertEqual(dtype, types_pb2.DT_INT64)
self.assertFalse(used_fallback)
def testConvertTensorWithIncorrectDtype(self):
converter = self.makePythonTensorConverter()
with self.assertRaises((TypeError, ValueError)):
converter.Convert(
constant_op.constant([1, 2, 3], dtypes.int32), types_pb2.DT_INT64)
#=============================================================================
# Convert list to tensor.
def testConvertListWithInferredDType(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
types_pb2.DT_INVALID)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertListWithExplicitDtype(self):
converter = self.makePythonTensorConverter()
result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]],
types_pb2.DT_INT64)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
self.assertEqual(dtype, types_pb2.DT_INT64)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertListWithIncompatibleDtype(self):
converter = self.makePythonTensorConverter()
with self.assertRaisesRegex(
TypeError, "Expected string, got .* of type 'int' instead."
"|Cannot convert .* to EagerTensor of dtype string"):
converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_STRING)
def testConvertListWithInconsistentDtype(self):
converter = self.makePythonTensorConverter()
with self.assertRaisesRegex(
(TypeError, ValueError),
"Can't convert Python sequence with mixed types to Tensor."
"|Failed to convert"):
converter.Convert([[1, 2], ["a", "b"]], types_pb2.DT_INVALID)
#=============================================================================
# Convert np.array to tensor.
def testConvertNumpyArrayWithInferredDType(self):
converter = self.makePythonTensorConverter()
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertNumpyArrayWithExplicitDtype(self):
converter = self.makePythonTensorConverter()
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]])
self.assertEqual(dtype, types_pb2.DT_INT64)
self.assertEqual(used_fallback, not context.executing_eagerly())
def testConvertNumpyArrayWithIncompatibleDtype(self):
converter = self.makePythonTensorConverter()
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)
with self.assertRaises((ValueError, TypeError)):
converter.Convert(x, types_pb2.DT_STRING)
def testConvertNumpyArrayWithUnsupportedDtype(self):
converter = self.makePythonTensorConverter()
x = np.array([[1, 2], ["a", "b"]], np.object)
with self.assertRaises((ValueError, TypeError)):
converter.Convert(x, types_pb2.DT_INVALID)
#=============================================================================
# Convert IndexedSlices to tensor.
def testConvertIndexedSlicesWithInferredDType(self):
converter = self.makePythonTensorConverter()
x = indexed_slices.IndexedSlices(
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
constant_op.constant([1], dtypes.int64, name="x_indices"),
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertTrue(used_fallback)
def testConvertIndexedSlicesWithExplicitDtype(self):
converter = self.makePythonTensorConverter()
x = indexed_slices.IndexedSlices(
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
constant_op.constant([1], dtypes.int64, name="x_indices"),
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32)
self.assertIsInstance(result, ops.Tensor)
self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]])
self.assertEqual(dtype, types_pb2.DT_INT32)
self.assertTrue(used_fallback)
def testConvertIndexedSlicesWithIncorrectDtype(self):
converter = self.makePythonTensorConverter()
x = indexed_slices.IndexedSlices(
constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"),
constant_op.constant([1], dtypes.int64, name="x_indices"),
constant_op.constant([3, 3], dtypes.int64, name="x_shape"))
with self.assertRaises((ValueError, TypeError)):
converter.Convert(x, types_pb2.DT_FLOAT)
if __name__ == "__main__":
googletest.main()

View File

@ -0,0 +1,120 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Note: This library is only used by python_tensor_converter_test. It is
// not meant to be used in other circumstances.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/framework/python_tensor_converter.h"
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#else
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#endif
namespace py = pybind11;
namespace tensorflow {
namespace {
Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) {
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data");
return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
}
Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) {
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle");
return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr));
}
Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) {
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager");
return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
}
Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) {
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name");
return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr));
}
Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
}
PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) {
Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr());
if (!tld) throw py::error_already_set();
Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get());
if (!py_is_eager) throw py::error_already_set();
bool is_eager = PyObject_IsTrue(py_is_eager.get());
// Initialize the eager context, if necessary.
TFE_Context* ctx = nullptr;
const char* device_name = nullptr;
if (is_eager) {
Safe_PyObjectPtr context_handle =
GetAttr_ContextHandle(py_eager_context.ptr());
if (!context_handle) throw py::error_already_set();
if (context_handle.get() == Py_None) {
throw std::runtime_error("Error retrieving context handle.");
}
Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get());
if (!py_device_name) {
throw std::runtime_error("Error retrieving device name.");
}
device_name = TFE_GetPythonString(py_device_name.get());
ctx = reinterpret_cast<TFE_Context*>(
PyCapsule_GetPointer(context_handle.get(), nullptr));
}
return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name);
}
py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj,
py::handle dtype) {
DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr()));
bool used_fallback = false;
Safe_PyObjectPtr converted =
self->Convert(obj.ptr(), dtype_enum, &used_fallback);
if (!converted) throw py::error_already_set();
PyObject* result = PyTuple_New(3);
PyTuple_SET_ITEM(result, 0, converted.release());
PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum));
PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False);
Py_INCREF(PyTuple_GET_ITEM(result, 1));
Py_INCREF(PyTuple_GET_ITEM(result, 2));
return result;
}
} // namespace
} // namespace tensorflow
PYBIND11_MODULE(_pywrap_python_tensor_converter, m) {
py::class_<tensorflow::PythonTensorConverter>(m, "PythonTensorConverter")
.def(py::init(&tensorflow::MakePythonTensorConverter))
.def("Convert", tensorflow::Convert);
}

View File

@ -400,3 +400,6 @@ tensorflow::TensorHandle::Tensor
[python_api_dispatcher] # python_api_dispatcher
tensorflow::PythonAPIDispatcher
[python_tensor_converter] # python_tensor_converter
tensorflow::PythonTensorConverter