From 3ac0818dea42c6699cda3cea26657f606f11a3cc Mon Sep 17 00:00:00 2001 From: Edward Loper <edloper@google.com> Date: Mon, 26 Oct 2020 17:42:57 -0700 Subject: [PATCH] Add PythonTensorConverter class, which can be used in c++ to efficiently convert PyObjects to tensors. PiperOrigin-RevId: 339155091 Change-Id: Icad20253f523e3ac3685d8d34c52e940f3889b57 --- tensorflow/c/eager/BUILD | 1 + tensorflow/python/BUILD | 75 +++++++ tensorflow/python/eager/BUILD | 1 + .../framework/python_tensor_converter.cc | 136 ++++++++++++ .../framework/python_tensor_converter.h | 76 +++++++ .../framework/python_tensor_converter_test.py | 208 ++++++++++++++++++ .../python_tensor_converter_wrapper.cc | 120 ++++++++++ .../tools/def_file_filter/symbols_pybind.txt | 3 + 8 files changed, 620 insertions(+) create mode 100644 tensorflow/python/framework/python_tensor_converter.cc create mode 100644 tensorflow/python/framework/python_tensor_converter.h create mode 100644 tensorflow/python/framework/python_tensor_converter_test.py create mode 100644 tensorflow/python/framework/python_tensor_converter_wrapper.cc diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c44d0ee6873..fa0fdbae861 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -110,6 +110,7 @@ filegroup( "abstract_function.h", "abstract_operation.h", "abstract_tensor_handle.h", + "c_api.h", "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 89c3719b943..903c2449715 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1777,6 +1777,79 @@ tf_py_test( ], ) +cc_library( + name = "python_tensor_converter", + srcs = ["framework/python_tensor_converter.cc"], + hdrs = ["framework/python_tensor_converter.h"], + deps = [ + ":cpp_python_util", + ":safe_pyobject_ptr", + "//tensorflow/c/eager:c_api", + "//tensorflow/core:protos_all_cc", + "//tensorflow/python/eager:pywrap_tfe_lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/strings", + ], +) + +# Note: this target is only used by python_tensor_converter_test. +tf_python_pybind_extension( + name = "_pywrap_python_tensor_converter", + srcs = ["framework/python_tensor_converter_wrapper.cc"], + hdrs = [ + "framework/python_tensor_converter.h", + "lib/core/numpy.h", + "//tensorflow/c:headers", + "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", + "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", + "//tensorflow/python/eager:pywrap_required_hdrs", + ], + module_name = "_pywrap_python_tensor_converter", + deps = [ + ":safe_pyobject_ptr_required_hdrs", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@pybind11", + "//third_party/python_runtime:headers", # buildcleaner: keep + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:lib", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:core_cpu_headers_lib", + "//tensorflow/core:lib_headers_for_pybind", + "//third_party/py/numpy:headers", + "//tensorflow/c:pywrap_required_hdrs", + "@com_google_absl//absl/types:span", + ] + if_static( + extra_deps = [ + "//tensorflow/core/protobuf:eager_service_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", + ], + otherwise = [ + "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", + "//tensorflow/core/protobuf:master_proto_cc_headers_only", + "//tensorflow/core/protobuf:worker_proto_cc_headers_only", + ], + ), +) + +tf_py_test( + name = "python_tensor_converter_test", + srcs = ["framework/python_tensor_converter_test.py"], + python_version = "PY3", + tags = ["no_pip"], + deps = [ + ":_pywrap_python_tensor_converter", + ":client_testlib", + ], +) + py_library( name = "framework_ops", # "ops" is already the name of a deprecated target srcs = ["framework/ops.py"], @@ -6026,6 +6099,7 @@ pywrap_tensorflow_macro( ":pybind11_proto", ":python_api_dispatcher", ":python_op_gen", + ":python_tensor_converter", ":safe_pyobject_ptr", ":tf_session_helper", "//third_party/python_runtime:headers", @@ -6096,6 +6170,7 @@ filegroup( ":py_exception_registry", # py_exception_registry ":py_func_lib", # py_func ":python_api_dispatcher", # python_api_dispatcher + ":python_tensor_converter", # python_tensor_converter ":python_op_gen", # python_op_gen ":safe_ptr", # checkpoint_reader "//tensorflow/c:checkpoint_reader", # checkpoint_reader diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index fed430a2f8e..789e9419d9e 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/python:py_seq_tensor", "//tensorflow/python:py_util", "//tensorflow/python:safe_ptr", + "//tensorflow/python:safe_pyobject_ptr", "//tensorflow/python:stack_trace", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", diff --git a/tensorflow/python/framework/python_tensor_converter.cc b/tensorflow/python/framework/python_tensor_converter.cc new file mode 100644 index 00000000000..f18c8a8c681 --- /dev/null +++ b/tensorflow/python/framework/python_tensor_converter.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/framework/python_tensor_converter.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/python/eager/pywrap_tensor.h" +#include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/util/util.h" + +#if PY_MAJOR_VERSION < 3 +// Python 2.x: +#define PY_INT_AS_LONG(x) (PyInt_AsLong(x)) +#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x)) +#else +// Python 3.x: +#define PY_INT_AS_LONG(x) (PyLong_AsLong(x)) +#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x)) +#endif + +namespace tensorflow { +namespace { + +// Returns `tensor.dtype._type_enum` as a DataType enum. Assumes that `tensor` +// is a python `Tensor` object. +// +// On error: sets a python AttributeError exception and returns DT_INVALID. +DataType DataTypeForTensor(PyObject* tensor) { + static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype"); + static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum"); + + Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr)); + if (!py_dtype) return DT_INVALID; + + Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr)); + if (!enum_field) return DT_INVALID; + + DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get())); + return result; +} + +// Check that actual_dtype == expected_dtype. If not, set an exception and +// return false. (If expected_dtype is DT_INVALID, then instead simply update +// its value to `actual_dtype` and return true.) +bool CheckDType(DataType actual_dtype, DataType& expected_dtype) { + if (expected_dtype == DT_INVALID) { + expected_dtype = actual_dtype; // set output parameter. + } else if (expected_dtype != actual_dtype) { + PyErr_SetString(PyExc_TypeError, + absl::StrCat("Expected ", DataType_Name(expected_dtype), + " but got ", DataType_Name(actual_dtype)) + .c_str()); + return false; + } + return true; +} + +} // namespace + +Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype, + bool* used_fallback) const { + // First, try converting `src` to a Tensor without calling back into Python. + if (ctx_) { // Eager mode + // TODO(b/164980194): Handle resource variables as well. (See + // ConvertToTensor function in pywrap_tfe_src.cc). + if (EagerTensor_CheckExact(src)) { + // `src` is already an eager tensor; check its type, and return it as-is. + if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr; + Py_INCREF(src); + return Safe_PyObjectPtr(src); + } else { + TFE_TensorHandle* handle = + tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_); + if (handle) { + Safe_PyObjectPtr result(EagerTensorFromHandle(handle)); + if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) { + return nullptr; + } + return result; + } else { + PyErr_Clear(); + } + } + } else { // Graph mode + if (swig::IsTensor(src)) { + DataType src_dtype = DataTypeForTensor(src); + if (src_dtype == DT_INVALID) return nullptr; + if (!CheckDType(src_dtype, dtype)) return nullptr; + Py_INCREF(src); + return Safe_PyObjectPtr(src); + } + } + + // Fallback: use the Python tf.convert_to_tensor function. + // Currently this is used: + // + // * In Eager mode: for anything that's not already an Eager tensor, or + // handled by `tensorflow::ConvertToEagerTensor`. (At time of writing + // for this comment, ConvertToEagerTensor handles simple values like ints, + // nested lists of simple values, and numpy arrays.) + // * In graph mode: for anything that's not already a tensor. + // + // TODO(b/164980194) Reduce/eliminate cases where fallback is used. + if (used_fallback) *used_fallback = true; + static PyObject* convert_to_tensor = + swig::GetRegisteredPyObject("tf.convert_to_tensor"); + if (!convert_to_tensor) return nullptr; + + Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2)); + Safe_PyObjectPtr kwargs(PyDict_New()); + Py_INCREF(src); + PyTuple_SetItem(args.get(), 0, src); + if (dtype != DT_INVALID) { + PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype)); + } + PyDict_SetItemString(kwargs.get(), "ctx", py_eager_context_); + Safe_PyObjectPtr result( + PyObject_Call(convert_to_tensor, args.get(), kwargs.get())); + if (!result) return nullptr; + dtype = DataTypeForTensor(result.get()); // set output parameter. + if (dtype == DT_INVALID) return nullptr; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/python/framework/python_tensor_converter.h b/tensorflow/python/framework/python_tensor_converter.h new file mode 100644 index 00000000000..faf1793d4cd --- /dev/null +++ b/tensorflow/python/framework/python_tensor_converter.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ +#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ + +#include <Python.h> + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" + +namespace tensorflow { + +// Converts PyObject* values to Tensors. +// +// This converter attempts to convert values as efficiently as possible; but +// it has fallback paths to handle any PyObject* value for which tensor +// conversion is defined. +class PythonTensorConverter { + public: + // Constructs a new PythonTensorConverter. + // + // Note: the arguments to this constructor may change in the future, as + // we move more of python tensor conversion from the Python layer to the + // c++ layer. + // + // Args: + // py_eager_context: the value of context.context() from eager/context.py. + // ctx: The c++ eager context, or nullptr in graph mode. + // device_name: The current device name. + // + // All three argument values must remain alive until `this` is deleted. + PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx, + const char* device_name) + : py_eager_context_(py_eager_context), + ctx_(ctx), + device_name_(device_name) {} + + // Converts `src` to a tensor (if it's not already one), and returns a new + // reference to the converted value. + // + // Args: + // src: The object that should be converted to a Tensor. + // dtype: The requested dtype. Use `DT_INVALID` if the dtype should be + // inferred from the `src` value (in which case `dtype` will be updated + // in-place to be the actual dtype of the converted value). + // used_fallback: Output parameter used to record whether the conversion + // was done by falling back to the Python `tf.convert_to_tensor()` + // function. This is for testing/logging purposes only. May be null. + // + // If `src` can't be converted to a tensor with the requested dtype, sets a + // Python exception and returns nullptr. + Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype, + bool* used_fallback = nullptr) const; + + private: + PyObject* py_eager_context_; + TFE_Context* ctx_; + const char* device_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ diff --git a/tensorflow/python/framework/python_tensor_converter_test.py b/tensorflow/python/framework/python_tensor_converter_test.py new file mode 100644 index 00000000000..a29f87f3e23 --- /dev/null +++ b/tensorflow/python/framework/python_tensor_converter_test.py @@ -0,0 +1,208 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.python_tensor_converter.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np + +from tensorflow.core.framework import types_pb2 +from tensorflow.python import _pywrap_python_tensor_converter +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import indexed_slices +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +@test_util.run_all_in_graph_and_eager_modes +class PythonTensorConverterTest(test_util.TensorFlowTestCase, + parameterized.TestCase): + + def setUp(self): + context.ensure_initialized() + super(PythonTensorConverterTest, self).setUp() + + def makePythonTensorConverter(self): + return _pywrap_python_tensor_converter.PythonTensorConverter( + context.context()) + + #============================================================================= + # Convert int to tensor. + + def testConvertIntWithInferredDType(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INVALID) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, 12) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertIntWithExplicitDtype(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert(12, types_pb2.DT_INT64) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, 12) + self.assertEqual(dtype, types_pb2.DT_INT64) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertIntWithIncompatibleDtype(self): + converter = self.makePythonTensorConverter() + with self.assertRaisesRegex( + TypeError, "Expected string, got 3 of type 'int' instead." + "|Cannot convert 3 to EagerTensor of dtype string"): + converter.Convert(3, types_pb2.DT_STRING) + + #============================================================================= + # Convert tensor to tensor. + + def testConvertTensorWithInferredDType(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert( + constant_op.constant([1, 2, 3]), types_pb2.DT_INVALID) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [1, 2, 3]) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertFalse(used_fallback) + + def testConvertTensorWithExplicitDtype(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert( + constant_op.constant([1, 2, 3], dtypes.int64), types_pb2.DT_INT64) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [1, 2, 3]) + self.assertEqual(dtype, types_pb2.DT_INT64) + self.assertFalse(used_fallback) + + def testConvertTensorWithIncorrectDtype(self): + converter = self.makePythonTensorConverter() + with self.assertRaises((TypeError, ValueError)): + converter.Convert( + constant_op.constant([1, 2, 3], dtypes.int32), types_pb2.DT_INT64) + + #============================================================================= + # Convert list to tensor. + + def testConvertListWithInferredDType(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], + types_pb2.DT_INVALID) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertListWithExplicitDtype(self): + converter = self.makePythonTensorConverter() + result, dtype, used_fallback = converter.Convert([[1, 2, 3], [4, 5, 6]], + types_pb2.DT_INT64) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) + self.assertEqual(dtype, types_pb2.DT_INT64) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertListWithIncompatibleDtype(self): + converter = self.makePythonTensorConverter() + with self.assertRaisesRegex( + TypeError, "Expected string, got .* of type 'int' instead." + "|Cannot convert .* to EagerTensor of dtype string"): + converter.Convert([[1, 2, 3], [4, 5, 6]], types_pb2.DT_STRING) + + def testConvertListWithInconsistentDtype(self): + converter = self.makePythonTensorConverter() + with self.assertRaisesRegex( + (TypeError, ValueError), + "Can't convert Python sequence with mixed types to Tensor." + "|Failed to convert"): + converter.Convert([[1, 2], ["a", "b"]], types_pb2.DT_INVALID) + + #============================================================================= + # Convert np.array to tensor. + + def testConvertNumpyArrayWithInferredDType(self): + converter = self.makePythonTensorConverter() + x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) + result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertNumpyArrayWithExplicitDtype(self): + converter = self.makePythonTensorConverter() + x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) + result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT64) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[1, 2, 3], [4, 5, 6]]) + self.assertEqual(dtype, types_pb2.DT_INT64) + self.assertEqual(used_fallback, not context.executing_eagerly()) + + def testConvertNumpyArrayWithIncompatibleDtype(self): + converter = self.makePythonTensorConverter() + x = np.array([[1, 2, 3], [4, 5, 6]], np.int32) + with self.assertRaises((ValueError, TypeError)): + converter.Convert(x, types_pb2.DT_STRING) + + def testConvertNumpyArrayWithUnsupportedDtype(self): + converter = self.makePythonTensorConverter() + x = np.array([[1, 2], ["a", "b"]], np.object) + with self.assertRaises((ValueError, TypeError)): + converter.Convert(x, types_pb2.DT_INVALID) + + #============================================================================= + # Convert IndexedSlices to tensor. + + def testConvertIndexedSlicesWithInferredDType(self): + converter = self.makePythonTensorConverter() + x = indexed_slices.IndexedSlices( + constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"), + constant_op.constant([1], dtypes.int64, name="x_indices"), + constant_op.constant([3, 3], dtypes.int64, name="x_shape")) + result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INVALID) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertTrue(used_fallback) + + def testConvertIndexedSlicesWithExplicitDtype(self): + converter = self.makePythonTensorConverter() + x = indexed_slices.IndexedSlices( + constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"), + constant_op.constant([1], dtypes.int64, name="x_indices"), + constant_op.constant([3, 3], dtypes.int64, name="x_shape")) + result, dtype, used_fallback = converter.Convert(x, types_pb2.DT_INT32) + self.assertIsInstance(result, ops.Tensor) + self.assertAllEqual(result, [[0, 0, 0], [1, 2, 3], [0, 0, 0]]) + self.assertEqual(dtype, types_pb2.DT_INT32) + self.assertTrue(used_fallback) + + def testConvertIndexedSlicesWithIncorrectDtype(self): + converter = self.makePythonTensorConverter() + x = indexed_slices.IndexedSlices( + constant_op.constant([[1, 2, 3]], dtypes.int32, name="x_values"), + constant_op.constant([1], dtypes.int64, name="x_indices"), + constant_op.constant([3, 3], dtypes.int64, name="x_shape")) + with self.assertRaises((ValueError, TypeError)): + converter.Convert(x, types_pb2.DT_FLOAT) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/python_tensor_converter_wrapper.cc b/tensorflow/python/framework/python_tensor_converter_wrapper.cc new file mode 100644 index 00000000000..33491869dc6 --- /dev/null +++ b/tensorflow/python/framework/python_tensor_converter_wrapper.cc @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Note: This library is only used by python_tensor_converter_test. It is +// not meant to be used in other circumstances. + +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "tensorflow/python/eager/pywrap_tfe.h" +#include "tensorflow/python/framework/python_tensor_converter.h" + +#if PY_MAJOR_VERSION < 3 +// Python 2.x: +#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x)) +#define PY_INT_AS_LONG(x) (PyInt_AsLong(x)) +#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x)) +#else +// Python 3.x: +#define PY_INT_AS_LONG(x) (PyLong_AsLong(x)) +#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x)) +#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x)) +#endif + +namespace py = pybind11; + +namespace tensorflow { +namespace { + +Safe_PyObjectPtr GetAttr_ThreadLocalData(PyObject* eager_context) { + static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_thread_local_data"); + return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr)); +} + +Safe_PyObjectPtr GetAttr_ContextHandle(PyObject* eager_context) { + static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_context_handle"); + return Safe_PyObjectPtr(PyObject_GetAttr(eager_context, attr)); +} + +Safe_PyObjectPtr GetAttr_IsEager(PyObject* tld) { + static PyObject* attr = PY_STRING_INTERN_FROM_STRING("is_eager"); + return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr)); +} + +Safe_PyObjectPtr GetAttr_DeviceName(PyObject* tld) { + static PyObject* attr = PY_STRING_INTERN_FROM_STRING("device_name"); + return Safe_PyObjectPtr(PyObject_GetAttr(tld, attr)); +} + +Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) { + static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum"); + return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr)); +} + +PythonTensorConverter MakePythonTensorConverter(py::handle py_eager_context) { + Safe_PyObjectPtr tld = GetAttr_ThreadLocalData(py_eager_context.ptr()); + if (!tld) throw py::error_already_set(); + + Safe_PyObjectPtr py_is_eager = GetAttr_IsEager(tld.get()); + if (!py_is_eager) throw py::error_already_set(); + bool is_eager = PyObject_IsTrue(py_is_eager.get()); + + // Initialize the eager context, if necessary. + TFE_Context* ctx = nullptr; + const char* device_name = nullptr; + if (is_eager) { + Safe_PyObjectPtr context_handle = + GetAttr_ContextHandle(py_eager_context.ptr()); + if (!context_handle) throw py::error_already_set(); + if (context_handle.get() == Py_None) { + throw std::runtime_error("Error retrieving context handle."); + } + Safe_PyObjectPtr py_device_name = GetAttr_DeviceName(tld.get()); + if (!py_device_name) { + throw std::runtime_error("Error retrieving device name."); + } + device_name = TFE_GetPythonString(py_device_name.get()); + ctx = reinterpret_cast<TFE_Context*>( + PyCapsule_GetPointer(context_handle.get(), nullptr)); + } + + return PythonTensorConverter(py_eager_context.ptr(), ctx, device_name); +} + +py::handle Convert(tensorflow::PythonTensorConverter* self, py::handle obj, + py::handle dtype) { + DataType dtype_enum = static_cast<DataType>(PY_INT_AS_LONG(dtype.ptr())); + bool used_fallback = false; + Safe_PyObjectPtr converted = + self->Convert(obj.ptr(), dtype_enum, &used_fallback); + if (!converted) throw py::error_already_set(); + + PyObject* result = PyTuple_New(3); + PyTuple_SET_ITEM(result, 0, converted.release()); + PyTuple_SET_ITEM(result, 1, PY_INT_FROM_LONG(dtype_enum)); + PyTuple_SET_ITEM(result, 2, used_fallback ? Py_True : Py_False); + Py_INCREF(PyTuple_GET_ITEM(result, 1)); + Py_INCREF(PyTuple_GET_ITEM(result, 2)); + return result; +} + +} // namespace +} // namespace tensorflow + +PYBIND11_MODULE(_pywrap_python_tensor_converter, m) { + py::class_<tensorflow::PythonTensorConverter>(m, "PythonTensorConverter") + .def(py::init(&tensorflow::MakePythonTensorConverter)) + .def("Convert", tensorflow::Convert); +} diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 222a6bb262b..af5b1a104f4 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -400,3 +400,6 @@ tensorflow::TensorHandle::Tensor [python_api_dispatcher] # python_api_dispatcher tensorflow::PythonAPIDispatcher + +[python_tensor_converter] # python_tensor_converter +tensorflow::PythonTensorConverter