From 56d41be853301f21a56031ec8dc8420c472d698e Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Thu, 19 Dec 2019 11:59:05 -0800 Subject: [PATCH] Moved custom NumPy data types to :ndarray_tensor_types Prior to this change qint*, quint* and resource were defined in Python as a single-field struct arrays, e.g. _np_qint8 = np.dtype([("qint8", np.int8)]) Several TensorFlow functions had to special-case struct arrays (which otherwise have type NPY_VOID) and infer the real dtype from struct fields. Having these data types defined and handled in C++ allows to minimize magic on the Python/C++ boundary. PiperOrigin-RevId: 286436332 Change-Id: I20581b6999efbea02aa0efeccaa2d8889ceadaf2 --- tensorflow/python/BUILD | 50 ++- tensorflow/python/data/util/BUILD | 1 + tensorflow/python/eager/BUILD | 1 - tensorflow/python/eager/pywrap_tensor.cc | 7 +- tensorflow/python/framework/dtypes.cc | 13 - tensorflow/python/framework/dtypes.py | 33 +- .../python/kernel_tests/py_func_test.py | 2 +- tensorflow/python/lib/core/bfloat16.cc | 10 +- tensorflow/python/lib/core/bfloat16_test.py | 4 +- .../python/lib/core/bfloat16_wrapper.cc | 24 ++ tensorflow/python/lib/core/ndarray_tensor.cc | 201 +++++++++++- .../python/lib/core/ndarray_tensor_bridge.cc | 102 ++++++- .../python/lib/core/ndarray_tensor_bridge.h | 4 + .../python/lib/core/ndarray_tensor_types.cc | 287 ------------------ .../python/lib/core/ndarray_tensor_types.h | 65 ---- .../def_file_filter/def_file_filter.py.tpl | 2 +- .../tools/def_file_filter/symbols_pybind.txt | 14 +- 17 files changed, 364 insertions(+), 456 deletions(-) create mode 100644 tensorflow/python/lib/core/bfloat16_wrapper.cc delete mode 100644 tensorflow/python/lib/core/ndarray_tensor_types.cc delete mode 100644 tensorflow/python/lib/core/ndarray_tensor_types.h diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index db2222c5cd8..47e989341e0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -394,6 +394,8 @@ cc_library( srcs = ["lib/core/numpy.cc"], hdrs = ["lib/core/numpy.h"], deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], @@ -405,12 +407,24 @@ cc_library( hdrs = ["lib/core/bfloat16.h"], deps = [ ":numpy_lib", + ":safe_ptr", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/python_runtime:headers", ], ) +tf_python_pybind_extension( + name = "_pywrap_bfloat16", + srcs = ["lib/core/bfloat16_wrapper.cc"], + hdrs = ["lib/core/bfloat16.h"], + module_name = "_pywrap_bfloat16", + deps = [ + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + cc_library( name = "ndarray_tensor_bridge", srcs = ["lib/core/ndarray_tensor_bridge.cc"], @@ -421,7 +435,7 @@ cc_library( ], ), deps = [ - ":ndarray_tensor_types", + ":bfloat16_lib", ":numpy_lib", "//tensorflow/c:c_api", "//tensorflow/core:lib", @@ -782,31 +796,6 @@ cc_library( ], ) -cc_library( - name = "ndarray_tensor_types", - srcs = ["lib/core/ndarray_tensor_types.cc"], - hdrs = ["lib/core/ndarray_tensor_types.h"], - deps = [ - ":bfloat16_lib", - ":numpy_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//third_party/python_runtime:headers", - "@com_google_absl//absl/container:flat_hash_set", - ], -) - -cc_library( - name = "ndarray_tensor_types_headers_lib", - hdrs = ["lib/core/ndarray_tensor_types.h"], - deps = [ - ":numpy_lib", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], @@ -815,8 +804,8 @@ cc_library( "//learning/deepmind/courier:__subpackages__", ]), deps = [ + ":bfloat16_lib", ":ndarray_tensor_bridge", - ":ndarray_tensor_types", ":numpy_lib", ":safe_ptr", "//tensorflow/c:c_api", @@ -1176,7 +1165,6 @@ tf_python_pybind_extension( srcs = ["framework/dtypes.cc"], module_name = "_dtypes", deps = [ - ":ndarray_tensor_types_headers_lib", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:protos_all_cc", "//third_party/eigen3", @@ -1190,6 +1178,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":_dtypes", + ":_pywrap_bfloat16", ":pywrap_tensorflow", "//tensorflow/core:protos_all_py", ], @@ -5518,6 +5507,7 @@ tf_py_wrap_cc( "//conditions:default": None, }), deps = [ + ":bfloat16_lib", ":cost_analyzer_lib", ":model_analyzer_lib", ":cpp_python_util", @@ -5587,8 +5577,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [ ":numpy_lib", # checkpoint_reader ":safe_ptr", # checkpoint_reader ":python_op_gen", # python_op_gen - ":bfloat16_lib", # _dtypes - ":ndarray_tensor_types", # _dtypes + ":bfloat16_lib", # bfloat16 "//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib "//tensorflow/core/util/tensor_bundle", # checkpoint_reader "//tensorflow/core/common_runtime/eager:eager_executor", # tfe @@ -6215,6 +6204,7 @@ cuda_py_test( ":client_testlib", ":constant_op", ":dtypes", + ":framework_for_generated_wrappers", ":framework_ops", ":training", ":variable_scope", diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index 4613d762030..b5dc355bf5f 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -156,6 +156,7 @@ py_test( deps = [ ":convert", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index df1a4095932..ad792ab70ba 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -44,7 +44,6 @@ cc_library( "//tensorflow/python:cpp_python_util", "//tensorflow/python:ndarray_tensor", "//tensorflow/python:ndarray_tensor_bridge", - "//tensorflow/python:ndarray_tensor_types", "//tensorflow/python:numpy_lib", "//tensorflow/python:py_seq_tensor", "//tensorflow/python:safe_ptr", diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 66c2b8573f3..519026f6456 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" -#include "tensorflow/python/lib/core/ndarray_tensor_types.h" #include "tensorflow/python/lib/core/numpy.h" #include "tensorflow/python/lib/core/py_seq_tensor.h" #include "tensorflow/python/lib/core/safe_ptr.h" @@ -289,15 +288,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, if (PyArray_Check(value)) { int desired_np_dtype = -1; if (dtype != tensorflow::DT_INVALID) { - PyArray_Descr* descr = nullptr; - if (!tensorflow::DataTypeToPyArray_Descr(dtype, &descr).ok()) { + if (!tensorflow::TF_DataType_to_PyArray_TYPE( + static_cast(dtype), &desired_np_dtype) + .ok()) { PyErr_SetString( PyExc_TypeError, tensorflow::strings::StrCat("Invalid dtype argument value ", dtype) .c_str()); return nullptr; } - desired_np_dtype = descr->type_num; } PyArrayObject* array = reinterpret_cast(value); int current_np_dtype = PyArray_TYPE(array); diff --git a/tensorflow/python/framework/dtypes.cc b/tensorflow/python/framework/dtypes.cc index c5efd68ef70..7c8521bd2d0 100644 --- a/tensorflow/python/framework/dtypes.cc +++ b/tensorflow/python/framework/dtypes.cc @@ -17,7 +17,6 @@ limitations under the License. #include "include/pybind11/pybind11.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/python/lib/core/ndarray_tensor_types.h" namespace { @@ -61,18 +60,6 @@ inline bool DataTypeIsNumPyCompatible(DataType dt) { namespace py = pybind11; PYBIND11_MODULE(_dtypes, m) { - tensorflow::MaybeRegisterCustomNumPyTypes(); - - m.attr("np_bfloat16") = - reinterpret_cast(tensorflow::BFLOAT16_DESCR); - m.attr("np_qint8") = reinterpret_cast(tensorflow::QINT8_DESCR); - m.attr("np_qint16") = reinterpret_cast(tensorflow::QINT16_DESCR); - m.attr("np_qint32") = reinterpret_cast(tensorflow::QINT32_DESCR); - m.attr("np_quint8") = reinterpret_cast(tensorflow::QUINT8_DESCR); - m.attr("np_quint16") = reinterpret_cast(tensorflow::QUINT16_DESCR); - m.attr("np_resource") = - reinterpret_cast(tensorflow::RESOURCE_DESCR); - py::class_(m, "DType") .def(py::init([](py::object obj) { auto id = static_cast(py::int_(obj)); diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 405184bd99c..44d98a9f73c 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -20,16 +20,17 @@ from __future__ import print_function import numpy as np from six.moves import builtins -# TODO(b/143110113): This import has to come first. This is a temporary -# workaround which fixes repeated proto registration on macOS. -# pylint: disable=g-bad-import-order, unused-import -from tensorflow.python import pywrap_tensorflow -# pylint: enable=g-bad-import-order, unused-import - from tensorflow.core.framework import types_pb2 +# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid +# protobuf errors where a file is defined twice on MacOS. +# pylint: disable=invalid-import-order,g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +from tensorflow.python import _pywrap_bfloat16 from tensorflow.python import _dtypes from tensorflow.python.util.tf_export import tf_export +_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() + # pylint: disable=slots-on-old-class @tf_export("dtypes.DType", "DType") @@ -424,18 +425,20 @@ _STRING_TO_TF["double_ref"] = float64_ref # Numpy representation for quantized dtypes. # -_np_qint8 = _dtypes.np_qint8 -_np_qint16 = _dtypes.np_qint16 -_np_qint32 = _dtypes.np_qint32 -_np_quint8 = _dtypes.np_quint8 -_np_quint16 = _dtypes.np_quint16 +# These are magic strings that are used in the swig wrapper to identify +# quantized types. +# TODO(mrry,keveman): Investigate Numpy type registration to replace this +# hard-coding of names. +_np_qint8 = np.dtype([("qint8", np.int8)]) +_np_quint8 = np.dtype([("quint8", np.uint8)]) +_np_qint16 = np.dtype([("qint16", np.int16)]) +_np_quint16 = np.dtype([("quint16", np.uint16)]) +_np_qint32 = np.dtype([("qint32", np.int32)]) -# Technically, _np_bfloat does not have to be a Python class, but existing -# code expects it to. -_np_bfloat16 = _dtypes.np_bfloat16.type +# _np_bfloat16 is defined by a module import. # Custom struct dtype for directly-fed ResourceHandles of supported type(s). -np_resource = _dtypes.np_resource +np_resource = np.dtype([("resource", np.ubyte)]) # Standard mappings between types_pb2.DataType values and numpy.dtypes. _NP_TO_TF = { diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index c1e09e3a384..5383410f999 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -321,7 +321,7 @@ class PyFuncTest(PyFuncTestBase): y, = script_ops.py_func(bad, [], [dtypes.float32]) with self.assertRaisesRegexp(errors.InternalError, - "Unsupported NumPy struct data type"): + "Unsupported numpy data type"): self.evaluate(y) @test_util.run_v1_only("b/120545219") diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 7fda88eaaf8..42b248a7ddb 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -21,19 +21,11 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/python/lib/core/numpy.h" +#include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { namespace { -struct PyDecrefDeleter { - void operator()(PyObject* p) const { Py_DECREF(p); } -}; - -using Safe_PyObjectPtr = std::unique_ptr; -Safe_PyObjectPtr make_safe(PyObject* object) { - return Safe_PyObjectPtr(object); -} - // Workarounds for Python 2 vs 3 API differences. #if PY_MAJOR_VERSION < 3 diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index ee17fea6e02..32453ae2296 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -24,10 +24,12 @@ import math import numpy as np # pylint: disable=unused-import,g-bad-import-order +from tensorflow.python import _pywrap_bfloat16 from tensorflow.python.framework import dtypes from tensorflow.python.platform import test -bfloat16 = dtypes._np_bfloat16 # pylint: disable=protected-access + +bfloat16 = _pywrap_bfloat16.TF_bfloat16_type() class Bfloat16Test(test.TestCase): diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc new file mode 100644 index 00000000000..4a8e180c154 --- /dev/null +++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc @@ -0,0 +1,24 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/pybind11.h" +#include "tensorflow/python/lib/core/bfloat16.h" + +PYBIND11_MODULE(_pywrap_bfloat16, m) { + tensorflow::RegisterNumpyBfloat16(); + + m.def("TF_bfloat16_type", + [] { return pybind11::handle(tensorflow::Bfloat16PyType()); }); +} diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index fcf41c2de6e..8c8362972be 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -21,13 +21,171 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/python/lib/core/bfloat16.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" -#include "tensorflow/python/lib/core/ndarray_tensor_types.h" #include "tensorflow/python/lib/core/numpy.h" namespace tensorflow { namespace { +char const* numpy_type_name(int numpy_type) { + switch (numpy_type) { +#define TYPE_CASE(s) \ + case s: \ + return #s + + TYPE_CASE(NPY_BOOL); + TYPE_CASE(NPY_BYTE); + TYPE_CASE(NPY_UBYTE); + TYPE_CASE(NPY_SHORT); + TYPE_CASE(NPY_USHORT); + TYPE_CASE(NPY_INT); + TYPE_CASE(NPY_UINT); + TYPE_CASE(NPY_LONG); + TYPE_CASE(NPY_ULONG); + TYPE_CASE(NPY_LONGLONG); + TYPE_CASE(NPY_ULONGLONG); + TYPE_CASE(NPY_FLOAT); + TYPE_CASE(NPY_DOUBLE); + TYPE_CASE(NPY_LONGDOUBLE); + TYPE_CASE(NPY_CFLOAT); + TYPE_CASE(NPY_CDOUBLE); + TYPE_CASE(NPY_CLONGDOUBLE); + TYPE_CASE(NPY_OBJECT); + TYPE_CASE(NPY_STRING); + TYPE_CASE(NPY_UNICODE); + TYPE_CASE(NPY_VOID); + TYPE_CASE(NPY_DATETIME); + TYPE_CASE(NPY_TIMEDELTA); + TYPE_CASE(NPY_HALF); + TYPE_CASE(NPY_NTYPES); + TYPE_CASE(NPY_NOTYPE); + TYPE_CASE(NPY_CHAR); + TYPE_CASE(NPY_USERDEF); + default: + return "not a numpy type"; + } +} + +Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, + TF_DataType* out_tf_datatype) { + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + if (PyDict_Next(descr->fields, &pos, &key, &value)) { + // In Python 3, the keys of numpy custom struct types are unicode, unlike + // Python 2, where the keys are bytes. + const char* key_string = + PyBytes_Check(key) ? PyBytes_AsString(key) + : PyBytes_AsString(PyUnicode_AsASCIIString(key)); + if (!key_string) { + return errors::Internal("Corrupt numpy type descriptor"); + } + tensorflow::string key = key_string; + // The typenames here should match the field names in the custom struct + // types constructed in test_util.py. + // TODO(mrry,keveman): Investigate Numpy type registration to replace this + // hard-coding of names. + if (key == "quint8") { + *out_tf_datatype = TF_QUINT8; + } else if (key == "qint8") { + *out_tf_datatype = TF_QINT8; + } else if (key == "qint16") { + *out_tf_datatype = TF_QINT16; + } else if (key == "quint16") { + *out_tf_datatype = TF_QUINT16; + } else if (key == "qint32") { + *out_tf_datatype = TF_QINT32; + } else if (key == "resource") { + *out_tf_datatype = TF_RESOURCE; + } else { + return errors::Internal("Unsupported numpy data type"); + } + return Status::OK(); + } + return errors::Internal("Unsupported numpy data type"); +} + +Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array, + TF_DataType* out_tf_datatype) { + int pyarray_type = PyArray_TYPE(array); + PyArray_Descr* descr = PyArray_DESCR(array); + switch (pyarray_type) { + case NPY_FLOAT16: + *out_tf_datatype = TF_HALF; + break; + case NPY_FLOAT32: + *out_tf_datatype = TF_FLOAT; + break; + case NPY_FLOAT64: + *out_tf_datatype = TF_DOUBLE; + break; + case NPY_INT32: + *out_tf_datatype = TF_INT32; + break; + case NPY_UINT8: + *out_tf_datatype = TF_UINT8; + break; + case NPY_UINT16: + *out_tf_datatype = TF_UINT16; + break; + case NPY_UINT32: + *out_tf_datatype = TF_UINT32; + break; + case NPY_UINT64: + *out_tf_datatype = TF_UINT64; + break; + case NPY_INT8: + *out_tf_datatype = TF_INT8; + break; + case NPY_INT16: + *out_tf_datatype = TF_INT16; + break; + case NPY_INT64: + *out_tf_datatype = TF_INT64; + break; + case NPY_BOOL: + *out_tf_datatype = TF_BOOL; + break; + case NPY_COMPLEX64: + *out_tf_datatype = TF_COMPLEX64; + break; + case NPY_COMPLEX128: + *out_tf_datatype = TF_COMPLEX128; + break; + case NPY_OBJECT: + case NPY_STRING: + case NPY_UNICODE: + *out_tf_datatype = TF_STRING; + break; + case NPY_VOID: + // Quantized types are currently represented as custom struct types. + // PyArray_TYPE returns NPY_VOID for structs, and we should look into + // descr to derive the actual type. + // Direct feeds of certain types of ResourceHandles are represented as a + // custom struct type. + return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype); + default: + if (pyarray_type == Bfloat16NumpyType()) { + *out_tf_datatype = TF_BFLOAT16; + break; + } else if (pyarray_type == NPY_ULONGLONG) { + // NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values + // might be different on certain platforms. + *out_tf_datatype = TF_UINT64; + break; + } else if (pyarray_type == NPY_LONGLONG) { + // NPY_LONGLONG is equivalent to NPY_INT64, while their enum values + // might be different on certain platforms. + *out_tf_datatype = TF_INT64; + break; + } + return errors::Internal("Unsupported numpy type: ", + numpy_type_name(pyarray_type)); + } + return Status::OK(); +} + Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len, PyObject** ptr_owner) { *ptr_owner = nullptr; @@ -186,6 +344,38 @@ Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor, return Status::OK(); } +// Determine the type description (PyArray_Descr) of a numpy ndarray to be +// created to represent an output Tensor. +Status GetPyArrayDescrForTensor(const TF_Tensor* tensor, + PyArray_Descr** descr) { + if (TF_TensorType(tensor) == TF_RESOURCE) { + PyObject* field = PyTuple_New(3); +#if PY_MAJOR_VERSION < 3 + PyTuple_SetItem(field, 0, PyBytes_FromString("resource")); +#else + PyTuple_SetItem(field, 0, PyUnicode_FromString("resource")); +#endif + PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE)); + PyTuple_SetItem(field, 2, PyLong_FromLong(1)); + PyObject* fields = PyList_New(1); + PyList_SetItem(fields, 0, field); + int convert_result = PyArray_DescrConverter(fields, descr); + Py_CLEAR(field); + Py_CLEAR(fields); + if (convert_result != 1) { + return errors::Internal("Failed to create numpy array description for ", + "TF_RESOURCE-type tensor"); + } + } else { + int type_num = -1; + TF_RETURN_IF_ERROR( + TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num)); + *descr = PyArray_DescrFromType(type_num); + } + + return Status::OK(); +} + inline void FastMemcpy(void* dst, const void* src, size_t size) { // clang-format off switch (size) { @@ -271,8 +461,7 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) { // Copy the TF_TensorData into a newly-created ndarray and return it. PyArray_Descr* descr = nullptr; - TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr( - static_cast(TF_TensorType(tensor.get())), &descr)); + TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr)); Safe_PyObjectPtr safe_out_array = tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0)); if (!safe_out_array) { @@ -310,11 +499,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) { // Convert numpy dtype to TensorFlow dtype. TF_DataType dtype = TF_FLOAT; - { - DataType tmp; - TF_RETURN_IF_ERROR(PyArray_DescrToDataType(PyArray_DESCR(array), &tmp)); - dtype = static_cast(tmp); - } + TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype)); tensorflow::int64 nelems = 1; gtl::InlinedVector dims; diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc index 485b34cd539..03ff77100d2 100644 --- a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc +++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc @@ -13,19 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" +// Must be included first. +#include "tensorflow/python/lib/core/numpy.h" #include -// Must be included first. -// clang-format: off -#include "tensorflow/python/lib/core/numpy.h" -// clang-format: on - #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/python/lib/core/ndarray_tensor_types.h" +#include "tensorflow/python/lib/core/bfloat16.h" +#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" namespace tensorflow { @@ -110,6 +107,85 @@ PyTypeObject TensorReleaserType = { nullptr, /* tp_richcompare */ }; +Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype, + int* out_pyarray_type) { + switch (tf_datatype) { + case TF_HALF: + *out_pyarray_type = NPY_FLOAT16; + break; + case TF_FLOAT: + *out_pyarray_type = NPY_FLOAT32; + break; + case TF_DOUBLE: + *out_pyarray_type = NPY_FLOAT64; + break; + case TF_INT32: + *out_pyarray_type = NPY_INT32; + break; + case TF_UINT32: + *out_pyarray_type = NPY_UINT32; + break; + case TF_UINT8: + *out_pyarray_type = NPY_UINT8; + break; + case TF_UINT16: + *out_pyarray_type = NPY_UINT16; + break; + case TF_INT8: + *out_pyarray_type = NPY_INT8; + break; + case TF_INT16: + *out_pyarray_type = NPY_INT16; + break; + case TF_INT64: + *out_pyarray_type = NPY_INT64; + break; + case TF_UINT64: + *out_pyarray_type = NPY_UINT64; + break; + case TF_BOOL: + *out_pyarray_type = NPY_BOOL; + break; + case TF_COMPLEX64: + *out_pyarray_type = NPY_COMPLEX64; + break; + case TF_COMPLEX128: + *out_pyarray_type = NPY_COMPLEX128; + break; + case TF_STRING: + *out_pyarray_type = NPY_OBJECT; + break; + case TF_RESOURCE: + *out_pyarray_type = NPY_VOID; + break; + // TODO(keveman): These should be changed to NPY_VOID, and the type used for + // the resulting numpy array should be the custom struct types that we + // expect for quantized types. + case TF_QINT8: + *out_pyarray_type = NPY_INT8; + break; + case TF_QUINT8: + *out_pyarray_type = NPY_UINT8; + break; + case TF_QINT16: + *out_pyarray_type = NPY_INT16; + break; + case TF_QUINT16: + *out_pyarray_type = NPY_UINT16; + break; + case TF_QINT32: + *out_pyarray_type = NPY_INT32; + break; + case TF_BFLOAT16: + *out_pyarray_type = Bfloat16NumpyType(); + break; + default: + return errors::Internal("Tensorflow type ", tf_datatype, + " not convertible to numpy dtype."); + } + return Status::OK(); +} + Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype, std::function destructor, PyObject** result) { if (dtype == DT_STRING || dtype == DT_RESOURCE) { @@ -117,11 +193,15 @@ Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype, "Cannot convert string or resource Tensors."); } - PyArray_Descr* descr = nullptr; - TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(dtype, &descr)); + int type_num = -1; + Status s = + TF_DataType_to_PyArray_TYPE(static_cast(dtype), &type_num); + if (!s.ok()) { + return s; + } + auto* np_array = reinterpret_cast( - PyArray_SimpleNewFromData(dim_size, dims, descr->type_num, data)); - CHECK_NE(np_array, nullptr); + PyArray_SimpleNewFromData(dim_size, dims, type_num, data)); PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA); if (PyType_Ready(&TensorReleaserType) == -1) { return errors::Unknown("Python type initialization failed."); diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.h b/tensorflow/python/lib/core/ndarray_tensor_bridge.h index d6943af8ed9..029c0d3ef0a 100644 --- a/tensorflow/python/lib/core/ndarray_tensor_bridge.h +++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.h @@ -42,6 +42,10 @@ void ClearDecrefCache(); Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype, std::function destructor, PyObject** result); +// Converts TF_DataType to the corresponding numpy type. +Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype, + int* out_pyarray_type); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_ diff --git a/tensorflow/python/lib/core/ndarray_tensor_types.cc b/tensorflow/python/lib/core/ndarray_tensor_types.cc deleted file mode 100644 index c255db4dd70..00000000000 --- a/tensorflow/python/lib/core/ndarray_tensor_types.cc +++ /dev/null @@ -1,287 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/python/lib/core/ndarray_tensor_types.h" - -#include - -// Must be included first. -// clang-format: off -#include "tensorflow/python/lib/core/numpy.h" -// clang-format: on - -#include "absl/container/flat_hash_set.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/python/lib/core/bfloat16.h" - -namespace tensorflow { - -PyArray_Descr* BFLOAT16_DESCR = nullptr; -PyArray_Descr* QINT8_DESCR = nullptr; -PyArray_Descr* QINT16_DESCR = nullptr; -PyArray_Descr* QINT32_DESCR = nullptr; -PyArray_Descr* QUINT8_DESCR = nullptr; -PyArray_Descr* QUINT16_DESCR = nullptr; -PyArray_Descr* RESOURCE_DESCR = nullptr; - -// Define a struct array data type `[(tag, type_num)]`. -PyArray_Descr* DefineStructTypeAlias(const char* tag, int type_num) { -#if PY_MAJOR_VERSION < 3 - auto* py_tag = PyBytes_FromString(tag); -#else - auto* py_tag = PyUnicode_FromString(tag); -#endif - auto* descr = PyArray_DescrFromType(type_num); - auto* py_tag_and_descr = PyTuple_Pack(2, py_tag, descr); - auto* obj = PyList_New(1); - PyList_SetItem(obj, 0, py_tag_and_descr); - PyArray_Descr* alias_descr = nullptr; - // TODO(slebedev): Switch to PyArray_DescrNewFromType because struct - // array dtypes could not be used for scalars. Note that this will - // require registering type conversions and UFunc specializations. - // See b/144230631. - CHECK_EQ(PyArray_DescrConverter(obj, &alias_descr), NPY_SUCCEED); - Py_DECREF(obj); - Py_DECREF(py_tag_and_descr); - Py_DECREF(py_tag); - Py_DECREF(descr); - CHECK_NE(alias_descr, nullptr); - return alias_descr; -} - -void MaybeRegisterCustomNumPyTypes() { - static bool registered = false; - if (registered) return; - ImportNumpy(); // Ensure NumPy is loaded. - // Make sure the tags are consistent with DataTypeToPyArray_Descr. - QINT8_DESCR = DefineStructTypeAlias("qint8", NPY_INT8); - QINT16_DESCR = DefineStructTypeAlias("qint16", NPY_INT16); - QINT32_DESCR = DefineStructTypeAlias("qint32", NPY_INT32); - QUINT8_DESCR = DefineStructTypeAlias("quint8", NPY_UINT8); - QUINT16_DESCR = DefineStructTypeAlias("quint16", NPY_UINT16); - RESOURCE_DESCR = DefineStructTypeAlias("resource", NPY_UBYTE); - RegisterNumpyBfloat16(); - BFLOAT16_DESCR = PyArray_DescrFromType(Bfloat16NumpyType()); - registered = true; -} - -const char* PyArray_DescrReprAsString(PyArray_Descr* descr) { - auto* descr_repr = PyObject_Repr(reinterpret_cast(descr)); - const char* result; -#if PY_MAJOR_VERSION < 3 - result = PyBytes_AsString(descr_repr); -#else - auto* tmp = PyUnicode_AsASCIIString(descr_repr); - result = PyBytes_AsString(tmp); - Py_DECREF(tmp); -#endif - - Py_DECREF(descr_repr); - return result; -} - -Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr) { - switch (dt) { - case DT_HALF: - *out_descr = PyArray_DescrFromType(NPY_FLOAT16); - break; - case DT_FLOAT: - *out_descr = PyArray_DescrFromType(NPY_FLOAT32); - break; - case DT_DOUBLE: - *out_descr = PyArray_DescrFromType(NPY_FLOAT64); - break; - case DT_INT32: - *out_descr = PyArray_DescrFromType(NPY_INT32); - break; - case DT_UINT32: - *out_descr = PyArray_DescrFromType(NPY_UINT32); - break; - case DT_UINT8: - *out_descr = PyArray_DescrFromType(NPY_UINT8); - break; - case DT_UINT16: - *out_descr = PyArray_DescrFromType(NPY_UINT16); - break; - case DT_INT8: - *out_descr = PyArray_DescrFromType(NPY_INT8); - break; - case DT_INT16: - *out_descr = PyArray_DescrFromType(NPY_INT16); - break; - case DT_INT64: - *out_descr = PyArray_DescrFromType(NPY_INT64); - break; - case DT_UINT64: - *out_descr = PyArray_DescrFromType(NPY_UINT64); - break; - case DT_BOOL: - *out_descr = PyArray_DescrFromType(NPY_BOOL); - break; - case DT_COMPLEX64: - *out_descr = PyArray_DescrFromType(NPY_COMPLEX64); - break; - case DT_COMPLEX128: - *out_descr = PyArray_DescrFromType(NPY_COMPLEX128); - break; - case DT_STRING: - *out_descr = PyArray_DescrFromType(NPY_OBJECT); - break; - case DT_QINT8: - *out_descr = PyArray_DescrFromType(NPY_INT8); - break; - case DT_QINT16: - *out_descr = PyArray_DescrFromType(NPY_INT16); - break; - case DT_QINT32: - *out_descr = PyArray_DescrFromType(NPY_INT32); - break; - case DT_QUINT8: - *out_descr = PyArray_DescrFromType(NPY_UINT8); - break; - case DT_QUINT16: - *out_descr = PyArray_DescrFromType(NPY_UINT16); - break; - case DT_RESOURCE: - *out_descr = PyArray_DescrFromType(NPY_UBYTE); - break; - case DT_BFLOAT16: - Py_INCREF(BFLOAT16_DESCR); - *out_descr = BFLOAT16_DESCR; - break; - default: - return errors::Internal("TensorFlow data type ", DataType_Name(dt), - " cannot be converted to a NumPy data type."); - } - - return Status::OK(); -} - -// NumPy defines fixed-width aliases for platform integer types. However, -// some types do not have a fixed-width alias. Specifically -// -// * on a LLP64 system NPY_INT32 == NPY_LONG therefore NPY_INT is not aliased; -// * on a LP64 system NPY_INT64 == NPY_LONG and NPY_LONGLONG is not aliased. -// -int MaybeResolveNumPyPlatformType(int type_num) { - switch (type_num) { -#if NPY_BITS_OF_INT == 32 && NPY_BITS_OF_LONGLONG == 32 - case NPY_INT: - return NPY_INT32; - case NPY_UINT: - return NPY_UINT32; -#endif -#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONGLONG == 64 - case NPY_LONGLONG: - return NPY_INT64; - case NPY_ULONGLONG: - return NPY_UINT64; -#endif - default: - return type_num; - } -} - -Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt) { - const int type_num = MaybeResolveNumPyPlatformType(descr->type_num); - switch (type_num) { - case NPY_FLOAT16: - *out_dt = DT_HALF; - break; - case NPY_FLOAT32: - *out_dt = DT_FLOAT; - break; - case NPY_FLOAT64: - *out_dt = DT_DOUBLE; - break; - case NPY_INT8: - *out_dt = DT_INT8; - break; - case NPY_INT16: - *out_dt = DT_INT16; - break; - case NPY_INT32: - *out_dt = DT_INT32; - break; - case NPY_INT64: - *out_dt = DT_INT64; - break; - case NPY_UINT8: - *out_dt = DT_UINT8; - break; - case NPY_UINT16: - *out_dt = DT_UINT16; - break; - case NPY_UINT32: - *out_dt = DT_UINT32; - break; - case NPY_UINT64: - *out_dt = DT_UINT64; - break; - case NPY_BOOL: - *out_dt = DT_BOOL; - break; - case NPY_COMPLEX64: - *out_dt = DT_COMPLEX64; - break; - case NPY_COMPLEX128: - *out_dt = DT_COMPLEX128; - break; - case NPY_OBJECT: - case NPY_STRING: - case NPY_UNICODE: - *out_dt = DT_STRING; - break; - case NPY_VOID: { - if (descr == QINT8_DESCR) { - *out_dt = DT_QINT8; - break; - } else if (descr == QINT16_DESCR) { - *out_dt = DT_QINT16; - break; - } else if (descr == QINT32_DESCR) { - *out_dt = DT_QINT32; - break; - } else if (descr == QUINT8_DESCR) { - *out_dt = DT_QUINT8; - break; - } else if (descr == QUINT16_DESCR) { - *out_dt = DT_QUINT16; - break; - } else if (descr == RESOURCE_DESCR) { - *out_dt = DT_RESOURCE; - break; - } - - return errors::Internal("Unsupported NumPy struct data type: ", - PyArray_DescrReprAsString(descr)); - } - default: - if (type_num == Bfloat16NumpyType()) { - *out_dt = DT_BFLOAT16; - break; - } - - return errors::Internal("Unregistered NumPy data type: ", - PyArray_DescrReprAsString(descr)); - } - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/python/lib/core/ndarray_tensor_types.h b/tensorflow/python/lib/core/ndarray_tensor_types.h deleted file mode 100644 index 5a4a9050886..00000000000 --- a/tensorflow/python/lib/core/ndarray_tensor_types.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_ -#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_ - -// Must be included first. -// clang-format: off -#include "tensorflow/python/lib/core/numpy.h" -// clang-format: on - -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { - -extern PyArray_Descr* QINT8_DESCR; -extern PyArray_Descr* QINT16_DESCR; -extern PyArray_Descr* QINT32_DESCR; -extern PyArray_Descr* QUINT8_DESCR; -extern PyArray_Descr* QUINT16_DESCR; -extern PyArray_Descr* RESOURCE_DESCR; -extern PyArray_Descr* BFLOAT16_DESCR; - -// Register custom NumPy types. -// -// This function must be called in order to be able to map TensorFlow -// data types which do not have a corresponding standard NumPy data type -// (e.g. bfloat16 or qint8). -// -// TODO(b/144230631): The name is slightly misleading, as the function only -// registers bfloat16 and defines structured aliases for other data types -// (e.g. qint8). -void MaybeRegisterCustomNumPyTypes(); - -// Returns a NumPy data type matching a given tensorflow::DataType. If the -// function call succeeds, the caller is responsible for DECREF'ing the -// resulting PyArray_Descr*. -// -// NumPy does not support quantized integer types, so TensorFlow defines -// structured aliases for them, e.g. tf.qint8 is represented as -// np.dtype([("qint8", np.int8)]). However, for historical reasons this -// function does not use these aliases, and instead returns the *aliased* -// types (np.int8 in the example). -// TODO(b/144230631): Return an alias instead of the aliased type. -Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr); - -// Returns a tensorflow::DataType corresponding to a given NumPy data type. -Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt); - -} // namespace tensorflow - -#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_ diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl index 82640ea42d4..f894c000ddc 100644 --- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl +++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl @@ -175,7 +175,7 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file): else: # If not a section header and not an empty line, then it's a symbol # line. e.g. `tensorflow::swig::IsSequence` - symbols[curr_lib].append(re.escape(line)) + symbols[curr_lib].append(line) lib_paths = [] with open(lib_paths_file, "r") as f: diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 23a5b201e8c..e657edc4fbf 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -43,6 +43,10 @@ tensorflow::tfprof::SerializeToString [graph_analyzer_tool] # graph_analyzer tensorflow::grappler::graph_analyzer::GraphAnalyzerTool +[bfloat16_lib] # bfloat16 +tensorflow::RegisterNumpyBfloat16 +tensorflow::Bfloat16PyType + [events_writer] # events_writer tensorflow::EventsWriter::Init tensorflow::EventsWriter::InitWithSuffix @@ -185,13 +189,3 @@ tensorflow::Set_TF_Status_from_Status [context] # tfe tensorflow::EagerContext::WaitForAndCloseRemoteContexts - -[ndarray_tensor_types] # _dtypes -tensorflow::MaybeRegisterCustomNumPyTypes -tensorflow::BFLOAT16_DESCR -tensorflow::QINT8_DESCR -tensorflow::QINT16_DESCR -tensorflow::QINT32_DESCR -tensorflow::QUINT8_DESCR -tensorflow::QUINT16_DESCR -tensorflow::RESOURCE_DESCR