Moved custom NumPy data types to :ndarray_tensor_types

Prior to this change, qint*, quint* and resource were defined in Python
as single-field struct dtypes, e.g.

_np_qint8 = np.dtype([("qint8", np.int8)])

Several TensorFlow functions had to special-case struct arrays (which
NumPy reports as NPY_VOID) and infer the real dtype from the struct
fields, as sketched below.
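
A minimal Python sketch of that field lookup, using the _np_qint8 dtype
above (the helper name is illustrative, not code in the tree):

def _struct_dtype_to_tf_name(dtype):
  assert dtype.kind == "V"        # struct dtypes report their type as NPY_VOID
  (field_name,) = dtype.fields    # exactly one field, e.g. "qint8"
  return field_name

assert _struct_dtype_to_tf_name(_np_qint8) == "qint8"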

Defining and handling these data types in C++ minimizes the magic at the
Python/C++ boundary.

Note 1: The defined data types are *not* registered with NumPy, because
NumPy requires every registered data type to also register casting
functions and UFunc specializations, which would require significantly
more code (illustrated below).
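
For example (pure NumPy, with numpy imported as np and _np_qint8 as
above): arrays of such a struct dtype can be created, but no casts or
UFunc loops exist for them, so even trivial arithmetic fails.

a = np.zeros(3, dtype=_np_qint8)
try:
  a + 1
except TypeError:  # no UFunc loop is registered for the struct dtype
  pass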

Note 2: Tensor->NumPy conversion for qint*, quint* and resource dtypes does
not use the defined types, so

tf.constant([42], dtype=tf.qint8).numpy()

returns an array of dtype np.int8. This behavior is unaffected by this change.
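
A small check of that behavior (assuming TF 2.x eager execution, with
tensorflow imported as tf and numpy as np):

arr = tf.constant([42], dtype=tf.qint8).numpy()
assert arr.dtype == np.int8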

PiperOrigin-RevId: 286392791
Change-Id: I0bdc55c2002195bca6273c94d3965d6620239985
Sergei Lebedev 2019-12-19 07:58:37 -08:00 committed by TensorFlower Gardener
parent 69111e174c
commit 40dab8918d
17 changed files with 456 additions and 364 deletions


@ -394,8 +394,6 @@ cc_library(
srcs = ["lib/core/numpy.cc"],
hdrs = ["lib/core/numpy.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
@ -407,24 +405,12 @@ cc_library(
hdrs = ["lib/core/bfloat16.h"],
deps = [
":numpy_lib",
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/python_runtime:headers",
],
)
tf_python_pybind_extension(
name = "_pywrap_bfloat16",
srcs = ["lib/core/bfloat16_wrapper.cc"],
hdrs = ["lib/core/bfloat16.h"],
module_name = "_pywrap_bfloat16",
deps = [
"//third_party/python_runtime:headers",
"@pybind11",
],
)
cc_library(
name = "ndarray_tensor_bridge",
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
@ -435,7 +421,7 @@ cc_library(
],
),
deps = [
":bfloat16_lib",
":ndarray_tensor_types",
":numpy_lib",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
@ -796,6 +782,31 @@ cc_library(
],
)
cc_library(
name = "ndarray_tensor_types",
srcs = ["lib/core/ndarray_tensor_types.cc"],
hdrs = ["lib/core/ndarray_tensor_types.h"],
deps = [
":bfloat16_lib",
":numpy_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/container:flat_hash_set",
],
)
cc_library(
name = "ndarray_tensor_types_headers_lib",
hdrs = ["lib/core/ndarray_tensor_types.h"],
deps = [
":numpy_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "ndarray_tensor",
srcs = ["lib/core/ndarray_tensor.cc"],
@ -804,8 +815,8 @@ cc_library(
"//learning/deepmind/courier:__subpackages__",
]),
deps = [
":bfloat16_lib",
":ndarray_tensor_bridge",
":ndarray_tensor_types",
":numpy_lib",
":safe_ptr",
"//tensorflow/c:c_api",
@ -1165,6 +1176,7 @@ tf_python_pybind_extension(
srcs = ["framework/dtypes.cc"],
module_name = "_dtypes",
deps = [
":ndarray_tensor_types_headers_lib",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
@ -1178,7 +1190,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":_dtypes",
":_pywrap_bfloat16",
":pywrap_tensorflow",
"//tensorflow/core:protos_all_py",
],
@ -5507,7 +5518,6 @@ tf_py_wrap_cc(
"//conditions:default": None,
}),
deps = [
":bfloat16_lib",
":cost_analyzer_lib",
":model_analyzer_lib",
":cpp_python_util",
@ -5577,7 +5587,8 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
":numpy_lib", # checkpoint_reader
":safe_ptr", # checkpoint_reader
":python_op_gen", # python_op_gen
":bfloat16_lib", # bfloat16
":bfloat16_lib", # _dtypes
":ndarray_tensor_types", # _dtypes
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
@ -6204,7 +6215,6 @@ cuda_py_test(
":client_testlib",
":constant_op",
":dtypes",
":framework_for_generated_wrappers",
":framework_ops",
":training",
":variable_scope",


@ -156,7 +156,6 @@ py_test(
deps = [
":convert",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:util",
],
)


@ -44,6 +44,7 @@ cc_library(
"//tensorflow/python:cpp_python_util",
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python:ndarray_tensor_bridge",
"//tensorflow/python:ndarray_tensor_types",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:safe_ptr",


@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
@ -288,15 +289,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
if (PyArray_Check(value)) {
int desired_np_dtype = -1;
if (dtype != tensorflow::DT_INVALID) {
if (!tensorflow::TF_DataType_to_PyArray_TYPE(
static_cast<TF_DataType>(dtype), &desired_np_dtype)
.ok()) {
PyArray_Descr* descr = nullptr;
if (!tensorflow::DataTypeToPyArray_Descr(dtype, &descr).ok()) {
PyErr_SetString(
PyExc_TypeError,
tensorflow::strings::StrCat("Invalid dtype argument value ", dtype)
.c_str());
return nullptr;
}
desired_np_dtype = descr->type_num;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
int current_np_dtype = PyArray_TYPE(array);


@ -17,6 +17,7 @@ limitations under the License.
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
namespace {
@ -60,6 +61,18 @@ inline bool DataTypeIsNumPyCompatible(DataType dt) {
namespace py = pybind11;
PYBIND11_MODULE(_dtypes, m) {
tensorflow::MaybeRegisterCustomNumPyTypes();
m.attr("np_bfloat16") =
reinterpret_cast<PyObject*>(tensorflow::BFLOAT16_DESCR);
m.attr("np_qint8") = reinterpret_cast<PyObject*>(tensorflow::QINT8_DESCR);
m.attr("np_qint16") = reinterpret_cast<PyObject*>(tensorflow::QINT16_DESCR);
m.attr("np_qint32") = reinterpret_cast<PyObject*>(tensorflow::QINT32_DESCR);
m.attr("np_quint8") = reinterpret_cast<PyObject*>(tensorflow::QUINT8_DESCR);
m.attr("np_quint16") = reinterpret_cast<PyObject*>(tensorflow::QUINT16_DESCR);
m.attr("np_resource") =
reinterpret_cast<PyObject*>(tensorflow::RESOURCE_DESCR);
py::class_<tensorflow::DataType>(m, "DType")
.def(py::init([](py::object obj) {
auto id = static_cast<int>(py::int_(obj));


@ -20,17 +20,16 @@ from __future__ import print_function
import numpy as np
from six.moves import builtins
# TODO(b/143110113): This import has to come first. This is a temporary
# workaround which fixes repeated proto registration on macOS.
# pylint: disable=g-bad-import-order, unused-import
from tensorflow.python import pywrap_tensorflow
# pylint: enable=g-bad-import-order, unused-import
from tensorflow.core.framework import types_pb2
# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
# protobuf errors where a file is defined twice on MacOS.
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python import _pywrap_bfloat16
from tensorflow.python import _dtypes
from tensorflow.python.util.tf_export import tf_export
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
# pylint: disable=slots-on-old-class
@tf_export("dtypes.DType", "DType")
@ -425,20 +424,18 @@ _STRING_TO_TF["double_ref"] = float64_ref
# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])
_np_qint8 = _dtypes.np_qint8
_np_qint16 = _dtypes.np_qint16
_np_qint32 = _dtypes.np_qint32
_np_quint8 = _dtypes.np_quint8
_np_quint16 = _dtypes.np_quint16
# _np_bfloat16 is defined by a module import.
# Technically, _np_bfloat16 does not have to be a Python class, but existing
# code expects it to be one.
_np_bfloat16 = _dtypes.np_bfloat16.type
# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
np_resource = _dtypes.np_resource
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = {


@ -321,7 +321,7 @@ class PyFuncTest(PyFuncTestBase):
y, = script_ops.py_func(bad, [], [dtypes.float32])
with self.assertRaisesRegexp(errors.InternalError,
"Unsupported numpy data type"):
"Unsupported NumPy struct data type"):
self.evaluate(y)
@test_util.run_v1_only("b/120545219")


@ -21,11 +21,19 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
namespace {
struct PyDecrefDeleter {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
Safe_PyObjectPtr make_safe(PyObject* object) {
return Safe_PyObjectPtr(object);
}
// Workarounds for Python 2 vs 3 API differences.
#if PY_MAJOR_VERSION < 3


@ -24,12 +24,10 @@ import math
import numpy as np
# pylint: disable=unused-import,g-bad-import-order
from tensorflow.python import _pywrap_bfloat16
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
bfloat16 = dtypes._np_bfloat16 # pylint: disable=protected-access
class Bfloat16Test(test.TestCase):


@ -1,24 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/pybind11.h"
#include "tensorflow/python/lib/core/bfloat16.h"
PYBIND11_MODULE(_pywrap_bfloat16, m) {
tensorflow::RegisterNumpyBfloat16();
m.def("TF_bfloat16_type",
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
}


@ -21,171 +21,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace tensorflow {
namespace {
char const* numpy_type_name(int numpy_type) {
switch (numpy_type) {
#define TYPE_CASE(s) \
case s: \
return #s
TYPE_CASE(NPY_BOOL);
TYPE_CASE(NPY_BYTE);
TYPE_CASE(NPY_UBYTE);
TYPE_CASE(NPY_SHORT);
TYPE_CASE(NPY_USHORT);
TYPE_CASE(NPY_INT);
TYPE_CASE(NPY_UINT);
TYPE_CASE(NPY_LONG);
TYPE_CASE(NPY_ULONG);
TYPE_CASE(NPY_LONGLONG);
TYPE_CASE(NPY_ULONGLONG);
TYPE_CASE(NPY_FLOAT);
TYPE_CASE(NPY_DOUBLE);
TYPE_CASE(NPY_LONGDOUBLE);
TYPE_CASE(NPY_CFLOAT);
TYPE_CASE(NPY_CDOUBLE);
TYPE_CASE(NPY_CLONGDOUBLE);
TYPE_CASE(NPY_OBJECT);
TYPE_CASE(NPY_STRING);
TYPE_CASE(NPY_UNICODE);
TYPE_CASE(NPY_VOID);
TYPE_CASE(NPY_DATETIME);
TYPE_CASE(NPY_TIMEDELTA);
TYPE_CASE(NPY_HALF);
TYPE_CASE(NPY_NTYPES);
TYPE_CASE(NPY_NOTYPE);
TYPE_CASE(NPY_CHAR);
TYPE_CASE(NPY_USERDEF);
default:
return "not a numpy type";
}
}
Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
TF_DataType* out_tf_datatype) {
PyObject* key;
PyObject* value;
Py_ssize_t pos = 0;
if (PyDict_Next(descr->fields, &pos, &key, &value)) {
// In Python 3, the keys of numpy custom struct types are unicode, unlike
// Python 2, where the keys are bytes.
const char* key_string =
PyBytes_Check(key) ? PyBytes_AsString(key)
: PyBytes_AsString(PyUnicode_AsASCIIString(key));
if (!key_string) {
return errors::Internal("Corrupt numpy type descriptor");
}
tensorflow::string key = key_string;
// The typenames here should match the field names in the custom struct
// types constructed in test_util.py.
// TODO(mrry,keveman): Investigate Numpy type registration to replace this
// hard-coding of names.
if (key == "quint8") {
*out_tf_datatype = TF_QUINT8;
} else if (key == "qint8") {
*out_tf_datatype = TF_QINT8;
} else if (key == "qint16") {
*out_tf_datatype = TF_QINT16;
} else if (key == "quint16") {
*out_tf_datatype = TF_QUINT16;
} else if (key == "qint32") {
*out_tf_datatype = TF_QINT32;
} else if (key == "resource") {
*out_tf_datatype = TF_RESOURCE;
} else {
return errors::Internal("Unsupported numpy data type");
}
return Status::OK();
}
return errors::Internal("Unsupported numpy data type");
}
Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
TF_DataType* out_tf_datatype) {
int pyarray_type = PyArray_TYPE(array);
PyArray_Descr* descr = PyArray_DESCR(array);
switch (pyarray_type) {
case NPY_FLOAT16:
*out_tf_datatype = TF_HALF;
break;
case NPY_FLOAT32:
*out_tf_datatype = TF_FLOAT;
break;
case NPY_FLOAT64:
*out_tf_datatype = TF_DOUBLE;
break;
case NPY_INT32:
*out_tf_datatype = TF_INT32;
break;
case NPY_UINT8:
*out_tf_datatype = TF_UINT8;
break;
case NPY_UINT16:
*out_tf_datatype = TF_UINT16;
break;
case NPY_UINT32:
*out_tf_datatype = TF_UINT32;
break;
case NPY_UINT64:
*out_tf_datatype = TF_UINT64;
break;
case NPY_INT8:
*out_tf_datatype = TF_INT8;
break;
case NPY_INT16:
*out_tf_datatype = TF_INT16;
break;
case NPY_INT64:
*out_tf_datatype = TF_INT64;
break;
case NPY_BOOL:
*out_tf_datatype = TF_BOOL;
break;
case NPY_COMPLEX64:
*out_tf_datatype = TF_COMPLEX64;
break;
case NPY_COMPLEX128:
*out_tf_datatype = TF_COMPLEX128;
break;
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE:
*out_tf_datatype = TF_STRING;
break;
case NPY_VOID:
// Quantized types are currently represented as custom struct types.
// PyArray_TYPE returns NPY_VOID for structs, and we should look into
// descr to derive the actual type.
// Direct feeds of certain types of ResourceHandles are represented as a
// custom struct type.
return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
default:
if (pyarray_type == Bfloat16NumpyType()) {
*out_tf_datatype = TF_BFLOAT16;
break;
} else if (pyarray_type == NPY_ULONGLONG) {
// NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
// might be different on certain platforms.
*out_tf_datatype = TF_UINT64;
break;
} else if (pyarray_type == NPY_LONGLONG) {
// NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
// might be different on certain platforms.
*out_tf_datatype = TF_INT64;
break;
}
return errors::Internal("Unsupported numpy type: ",
numpy_type_name(pyarray_type));
}
return Status::OK();
}
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
PyObject** ptr_owner) {
*ptr_owner = nullptr;
@ -344,38 +186,6 @@ Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor,
return Status::OK();
}
// Determine the type description (PyArray_Descr) of a numpy ndarray to be
// created to represent an output Tensor.
Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
PyArray_Descr** descr) {
if (TF_TensorType(tensor) == TF_RESOURCE) {
PyObject* field = PyTuple_New(3);
#if PY_MAJOR_VERSION < 3
PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
#else
PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
#endif
PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
PyTuple_SetItem(field, 2, PyLong_FromLong(1));
PyObject* fields = PyList_New(1);
PyList_SetItem(fields, 0, field);
int convert_result = PyArray_DescrConverter(fields, descr);
Py_CLEAR(field);
Py_CLEAR(fields);
if (convert_result != 1) {
return errors::Internal("Failed to create numpy array description for ",
"TF_RESOURCE-type tensor");
}
} else {
int type_num = -1;
TF_RETURN_IF_ERROR(
TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
*descr = PyArray_DescrFromType(type_num);
}
return Status::OK();
}
inline void FastMemcpy(void* dst, const void* src, size_t size) {
// clang-format off
switch (size) {
@ -461,7 +271,8 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
// Copy the TF_TensorData into a newly-created ndarray and return it.
PyArray_Descr* descr = nullptr;
TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(
static_cast<DataType>(TF_TensorType(tensor.get())), &descr));
Safe_PyObjectPtr safe_out_array =
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
if (!safe_out_array) {
@ -499,7 +310,11 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
// Convert numpy dtype to TensorFlow dtype.
TF_DataType dtype = TF_FLOAT;
TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
{
DataType tmp;
TF_RETURN_IF_ERROR(PyArray_DescrToDataType(PyArray_DESCR(array), &tmp));
dtype = static_cast<TF_DataType>(tmp);
}
tensorflow::int64 nelems = 1;
gtl::InlinedVector<int64_t, 4> dims;


@ -13,16 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Must be included first.
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include <vector>
// Must be included first.
// clang-format: off
#include "tensorflow/python/lib/core/numpy.h"
// clang-format: on
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
namespace tensorflow {
@ -107,85 +110,6 @@ PyTypeObject TensorReleaserType = {
nullptr, /* tp_richcompare */
};
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
int* out_pyarray_type) {
switch (tf_datatype) {
case TF_HALF:
*out_pyarray_type = NPY_FLOAT16;
break;
case TF_FLOAT:
*out_pyarray_type = NPY_FLOAT32;
break;
case TF_DOUBLE:
*out_pyarray_type = NPY_FLOAT64;
break;
case TF_INT32:
*out_pyarray_type = NPY_INT32;
break;
case TF_UINT32:
*out_pyarray_type = NPY_UINT32;
break;
case TF_UINT8:
*out_pyarray_type = NPY_UINT8;
break;
case TF_UINT16:
*out_pyarray_type = NPY_UINT16;
break;
case TF_INT8:
*out_pyarray_type = NPY_INT8;
break;
case TF_INT16:
*out_pyarray_type = NPY_INT16;
break;
case TF_INT64:
*out_pyarray_type = NPY_INT64;
break;
case TF_UINT64:
*out_pyarray_type = NPY_UINT64;
break;
case TF_BOOL:
*out_pyarray_type = NPY_BOOL;
break;
case TF_COMPLEX64:
*out_pyarray_type = NPY_COMPLEX64;
break;
case TF_COMPLEX128:
*out_pyarray_type = NPY_COMPLEX128;
break;
case TF_STRING:
*out_pyarray_type = NPY_OBJECT;
break;
case TF_RESOURCE:
*out_pyarray_type = NPY_VOID;
break;
// TODO(keveman): These should be changed to NPY_VOID, and the type used for
// the resulting numpy array should be the custom struct types that we
// expect for quantized types.
case TF_QINT8:
*out_pyarray_type = NPY_INT8;
break;
case TF_QUINT8:
*out_pyarray_type = NPY_UINT8;
break;
case TF_QINT16:
*out_pyarray_type = NPY_INT16;
break;
case TF_QUINT16:
*out_pyarray_type = NPY_UINT16;
break;
case TF_QINT32:
*out_pyarray_type = NPY_INT32;
break;
case TF_BFLOAT16:
*out_pyarray_type = Bfloat16NumpyType();
break;
default:
return errors::Internal("Tensorflow type ", tf_datatype,
" not convertible to numpy dtype.");
}
return Status::OK();
}
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
std::function<void()> destructor, PyObject** result) {
if (dtype == DT_STRING || dtype == DT_RESOURCE) {
@ -193,15 +117,11 @@ Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
"Cannot convert string or resource Tensors.");
}
int type_num = -1;
Status s =
TF_DataType_to_PyArray_TYPE(static_cast<TF_DataType>(dtype), &type_num);
if (!s.ok()) {
return s;
}
PyArray_Descr* descr = nullptr;
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(dtype, &descr));
auto* np_array = reinterpret_cast<PyArrayObject*>(
PyArray_SimpleNewFromData(dim_size, dims, type_num, data));
PyArray_SimpleNewFromData(dim_size, dims, descr->type_num, data));
CHECK_NE(np_array, nullptr);
PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);
if (PyType_Ready(&TensorReleaserType) == -1) {
return errors::Unknown("Python type initialization failed.");


@ -42,10 +42,6 @@ void ClearDecrefCache();
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
std::function<void()> destructor, PyObject** result);
// Converts TF_DataType to the corresponding numpy type.
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
int* out_pyarray_type);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_


@ -0,0 +1,287 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
#include <Python.h>
// Must be included first.
// clang-format: off
#include "tensorflow/python/lib/core/numpy.h"
// clang-format: on
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"
namespace tensorflow {
PyArray_Descr* BFLOAT16_DESCR = nullptr;
PyArray_Descr* QINT8_DESCR = nullptr;
PyArray_Descr* QINT16_DESCR = nullptr;
PyArray_Descr* QINT32_DESCR = nullptr;
PyArray_Descr* QUINT8_DESCR = nullptr;
PyArray_Descr* QUINT16_DESCR = nullptr;
PyArray_Descr* RESOURCE_DESCR = nullptr;
// Define a struct array data type `[(tag, type_num)]`.
PyArray_Descr* DefineStructTypeAlias(const char* tag, int type_num) {
#if PY_MAJOR_VERSION < 3
auto* py_tag = PyBytes_FromString(tag);
#else
auto* py_tag = PyUnicode_FromString(tag);
#endif
auto* descr = PyArray_DescrFromType(type_num);
auto* py_tag_and_descr = PyTuple_Pack(2, py_tag, descr);
auto* obj = PyList_New(1);
PyList_SetItem(obj, 0, py_tag_and_descr);
PyArray_Descr* alias_descr = nullptr;
// TODO(slebedev): Switch to PyArray_DescrNewFromType because struct
// array dtypes could not be used for scalars. Note that this will
// require registering type conversions and UFunc specializations.
// See b/144230631.
CHECK_EQ(PyArray_DescrConverter(obj, &alias_descr), NPY_SUCCEED);
// `obj` holds the only reference to `py_tag_and_descr`, which PyList_SetItem
// stole, so releasing the list also releases the tuple.
Py_DECREF(obj);
Py_DECREF(py_tag);
Py_DECREF(descr);
CHECK_NE(alias_descr, nullptr);
return alias_descr;
}
void MaybeRegisterCustomNumPyTypes() {
static bool registered = false;
if (registered) return;
ImportNumpy(); // Ensure NumPy is loaded.
// Make sure the tags are consistent with DataTypeToPyArray_Descr.
QINT8_DESCR = DefineStructTypeAlias("qint8", NPY_INT8);
QINT16_DESCR = DefineStructTypeAlias("qint16", NPY_INT16);
QINT32_DESCR = DefineStructTypeAlias("qint32", NPY_INT32);
QUINT8_DESCR = DefineStructTypeAlias("quint8", NPY_UINT8);
QUINT16_DESCR = DefineStructTypeAlias("quint16", NPY_UINT16);
RESOURCE_DESCR = DefineStructTypeAlias("resource", NPY_UBYTE);
RegisterNumpyBfloat16();
BFLOAT16_DESCR = PyArray_DescrFromType(Bfloat16NumpyType());
registered = true;
}
// Returns a string representation of `descr`, suitable for error messages.
// The value is copied out because the backing Python objects are released
// before returning.
std::string PyArray_DescrReprAsString(PyArray_Descr* descr) {
auto* descr_repr = PyObject_Repr(reinterpret_cast<PyObject*>(descr));
std::string result;
#if PY_MAJOR_VERSION < 3
result = PyBytes_AsString(descr_repr);
#else
auto* tmp = PyUnicode_AsASCIIString(descr_repr);
result = PyBytes_AsString(tmp);
Py_DECREF(tmp);
#endif
Py_DECREF(descr_repr);
return result;
}
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr) {
switch (dt) {
case DT_HALF:
*out_descr = PyArray_DescrFromType(NPY_FLOAT16);
break;
case DT_FLOAT:
*out_descr = PyArray_DescrFromType(NPY_FLOAT32);
break;
case DT_DOUBLE:
*out_descr = PyArray_DescrFromType(NPY_FLOAT64);
break;
case DT_INT32:
*out_descr = PyArray_DescrFromType(NPY_INT32);
break;
case DT_UINT32:
*out_descr = PyArray_DescrFromType(NPY_UINT32);
break;
case DT_UINT8:
*out_descr = PyArray_DescrFromType(NPY_UINT8);
break;
case DT_UINT16:
*out_descr = PyArray_DescrFromType(NPY_UINT16);
break;
case DT_INT8:
*out_descr = PyArray_DescrFromType(NPY_INT8);
break;
case DT_INT16:
*out_descr = PyArray_DescrFromType(NPY_INT16);
break;
case DT_INT64:
*out_descr = PyArray_DescrFromType(NPY_INT64);
break;
case DT_UINT64:
*out_descr = PyArray_DescrFromType(NPY_UINT64);
break;
case DT_BOOL:
*out_descr = PyArray_DescrFromType(NPY_BOOL);
break;
case DT_COMPLEX64:
*out_descr = PyArray_DescrFromType(NPY_COMPLEX64);
break;
case DT_COMPLEX128:
*out_descr = PyArray_DescrFromType(NPY_COMPLEX128);
break;
case DT_STRING:
*out_descr = PyArray_DescrFromType(NPY_OBJECT);
break;
case DT_QINT8:
*out_descr = PyArray_DescrFromType(NPY_INT8);
break;
case DT_QINT16:
*out_descr = PyArray_DescrFromType(NPY_INT16);
break;
case DT_QINT32:
*out_descr = PyArray_DescrFromType(NPY_INT32);
break;
case DT_QUINT8:
*out_descr = PyArray_DescrFromType(NPY_UINT8);
break;
case DT_QUINT16:
*out_descr = PyArray_DescrFromType(NPY_UINT16);
break;
case DT_RESOURCE:
*out_descr = PyArray_DescrFromType(NPY_UBYTE);
break;
case DT_BFLOAT16:
Py_INCREF(BFLOAT16_DESCR);
*out_descr = BFLOAT16_DESCR;
break;
default:
return errors::Internal("TensorFlow data type ", DataType_Name(dt),
" cannot be converted to a NumPy data type.");
}
return Status::OK();
}
// NumPy defines fixed-width aliases for platform integer types. However,
// some platform types do not have a fixed-width alias. Specifically,
//
// * on an LLP64 system NPY_INT32 == NPY_LONG, so NPY_INT is not aliased;
// * on an LP64 system NPY_INT64 == NPY_LONG, so NPY_LONGLONG is not aliased.
//
int MaybeResolveNumPyPlatformType(int type_num) {
switch (type_num) {
#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONG == 32
case NPY_INT:
return NPY_INT32;
case NPY_UINT:
return NPY_UINT32;
#endif
#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONGLONG == 64
case NPY_LONGLONG:
return NPY_INT64;
case NPY_ULONGLONG:
return NPY_UINT64;
#endif
default:
return type_num;
}
}
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt) {
const int type_num = MaybeResolveNumPyPlatformType(descr->type_num);
switch (type_num) {
case NPY_FLOAT16:
*out_dt = DT_HALF;
break;
case NPY_FLOAT32:
*out_dt = DT_FLOAT;
break;
case NPY_FLOAT64:
*out_dt = DT_DOUBLE;
break;
case NPY_INT8:
*out_dt = DT_INT8;
break;
case NPY_INT16:
*out_dt = DT_INT16;
break;
case NPY_INT32:
*out_dt = DT_INT32;
break;
case NPY_INT64:
*out_dt = DT_INT64;
break;
case NPY_UINT8:
*out_dt = DT_UINT8;
break;
case NPY_UINT16:
*out_dt = DT_UINT16;
break;
case NPY_UINT32:
*out_dt = DT_UINT32;
break;
case NPY_UINT64:
*out_dt = DT_UINT64;
break;
case NPY_BOOL:
*out_dt = DT_BOOL;
break;
case NPY_COMPLEX64:
*out_dt = DT_COMPLEX64;
break;
case NPY_COMPLEX128:
*out_dt = DT_COMPLEX128;
break;
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE:
*out_dt = DT_STRING;
break;
case NPY_VOID: {
if (descr == QINT8_DESCR) {
*out_dt = DT_QINT8;
break;
} else if (descr == QINT16_DESCR) {
*out_dt = DT_QINT16;
break;
} else if (descr == QINT32_DESCR) {
*out_dt = DT_QINT32;
break;
} else if (descr == QUINT8_DESCR) {
*out_dt = DT_QUINT8;
break;
} else if (descr == QUINT16_DESCR) {
*out_dt = DT_QUINT16;
break;
} else if (descr == RESOURCE_DESCR) {
*out_dt = DT_RESOURCE;
break;
}
return errors::Internal("Unsupported NumPy struct data type: ",
PyArray_DescrReprAsString(descr));
}
default:
if (type_num == Bfloat16NumpyType()) {
*out_dt = DT_BFLOAT16;
break;
}
return errors::Internal("Unregistered NumPy data type: ",
PyArray_DescrReprAsString(descr));
}
return Status::OK();
}
} // namespace tensorflow


@ -0,0 +1,65 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
// Must be included first.
// clang-format: off
#include "tensorflow/python/lib/core/numpy.h"
// clang-format: on
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
extern PyArray_Descr* QINT8_DESCR;
extern PyArray_Descr* QINT16_DESCR;
extern PyArray_Descr* QINT32_DESCR;
extern PyArray_Descr* QUINT8_DESCR;
extern PyArray_Descr* QUINT16_DESCR;
extern PyArray_Descr* RESOURCE_DESCR;
extern PyArray_Descr* BFLOAT16_DESCR;
// Register custom NumPy types.
//
// This function must be called in order to be able to map TensorFlow
// data types which do not have a corresponding standard NumPy data type
// (e.g. bfloat16 or qint8).
//
// TODO(b/144230631): The name is slightly misleading, as the function only
// registers bfloat16 and defines structured aliases for other data types
// (e.g. qint8).
void MaybeRegisterCustomNumPyTypes();
// Returns a NumPy data type matching a given tensorflow::DataType. If the
// function call succeeds, the caller is responsible for DECREF'ing the
// resulting PyArray_Descr*.
//
// NumPy does not support quantized integer types, so TensorFlow defines
// structured aliases for them, e.g. tf.qint8 is represented as
// np.dtype([("qint8", np.int8)]). However, for historical reasons this
// function does not use these aliases, and instead returns the *aliased*
// types (np.int8 in the example).
// TODO(b/144230631): Return an alias instead of the aliased type.
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr);
// Returns a tensorflow::DataType corresponding to a given NumPy data type.
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_


@ -175,7 +175,7 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file):
else:
# If not a section header and not an empty line, then it's a symbol
# line. e.g. `tensorflow::swig::IsSequence`
symbols[curr_lib].append(line)
symbols[curr_lib].append(re.escape(line))
lib_paths = []
with open(lib_paths_file, "r") as f:


@ -43,10 +43,6 @@ tensorflow::tfprof::SerializeToString
[graph_analyzer_tool] # graph_analyzer
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
[bfloat16_lib] # bfloat16
tensorflow::RegisterNumpyBfloat16
tensorflow::Bfloat16PyType
[events_writer] # events_writer
tensorflow::EventsWriter::Init
tensorflow::EventsWriter::InitWithSuffix
@ -189,3 +185,13 @@ tensorflow::Set_TF_Status_from_Status
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
[ndarray_tensor_types] # _dtypes
tensorflow::MaybeRegisterCustomNumPyTypes
tensorflow::BFLOAT16_DESCR
tensorflow::QINT8_DESCR
tensorflow::QINT16_DESCR
tensorflow::QINT32_DESCR
tensorflow::QUINT8_DESCR
tensorflow::QUINT16_DESCR
tensorflow::RESOURCE_DESCR