Moved custom NumPy data types to :ndarray_tensor_types
Prior to this change qint*, quint* and resource were defined in Python as a single-field struct arrays, e.g. _np_qint8 = np.dtype([("qint8", np.int8)]) Several TensorFlow functions had to special-case struct arrays (which otherwise have type NPY_VOID) and infer the real dtype from struct fields. Having these data types defined and handled in C++ allows to minimize magic on the Python/C++ boundary. PiperOrigin-RevId: 286436332 Change-Id: I20581b6999efbea02aa0efeccaa2d8889ceadaf2
This commit is contained in:
parent
7c23f557e0
commit
56d41be853
@ -394,6 +394,8 @@ cc_library(
|
|||||||
srcs = ["lib/core/numpy.cc"],
|
srcs = ["lib/core/numpy.cc"],
|
||||||
hdrs = ["lib/core/numpy.h"],
|
hdrs = ["lib/core/numpy.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
"//third_party/py/numpy:headers",
|
"//third_party/py/numpy:headers",
|
||||||
"//third_party/python_runtime:headers",
|
"//third_party/python_runtime:headers",
|
||||||
],
|
],
|
||||||
@ -405,12 +407,24 @@ cc_library(
|
|||||||
hdrs = ["lib/core/bfloat16.h"],
|
hdrs = ["lib/core/bfloat16.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
|
":safe_ptr",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//third_party/python_runtime:headers",
|
"//third_party/python_runtime:headers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_bfloat16",
|
||||||
|
srcs = ["lib/core/bfloat16_wrapper.cc"],
|
||||||
|
hdrs = ["lib/core/bfloat16.h"],
|
||||||
|
module_name = "_pywrap_bfloat16",
|
||||||
|
deps = [
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ndarray_tensor_bridge",
|
name = "ndarray_tensor_bridge",
|
||||||
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
|
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
|
||||||
@ -421,7 +435,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
deps = [
|
deps = [
|
||||||
":ndarray_tensor_types",
|
":bfloat16_lib",
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -782,31 +796,6 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "ndarray_tensor_types",
|
|
||||||
srcs = ["lib/core/ndarray_tensor_types.cc"],
|
|
||||||
hdrs = ["lib/core/ndarray_tensor_types.h"],
|
|
||||||
deps = [
|
|
||||||
":bfloat16_lib",
|
|
||||||
":numpy_lib",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//third_party/python_runtime:headers",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "ndarray_tensor_types_headers_lib",
|
|
||||||
hdrs = ["lib/core/ndarray_tensor_types.h"],
|
|
||||||
deps = [
|
|
||||||
":numpy_lib",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ndarray_tensor",
|
name = "ndarray_tensor",
|
||||||
srcs = ["lib/core/ndarray_tensor.cc"],
|
srcs = ["lib/core/ndarray_tensor.cc"],
|
||||||
@ -815,8 +804,8 @@ cc_library(
|
|||||||
"//learning/deepmind/courier:__subpackages__",
|
"//learning/deepmind/courier:__subpackages__",
|
||||||
]),
|
]),
|
||||||
deps = [
|
deps = [
|
||||||
|
":bfloat16_lib",
|
||||||
":ndarray_tensor_bridge",
|
":ndarray_tensor_bridge",
|
||||||
":ndarray_tensor_types",
|
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
":safe_ptr",
|
":safe_ptr",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
@ -1176,7 +1165,6 @@ tf_python_pybind_extension(
|
|||||||
srcs = ["framework/dtypes.cc"],
|
srcs = ["framework/dtypes.cc"],
|
||||||
module_name = "_dtypes",
|
module_name = "_dtypes",
|
||||||
deps = [
|
deps = [
|
||||||
":ndarray_tensor_types_headers_lib",
|
|
||||||
"//tensorflow/core:framework_headers_lib",
|
"//tensorflow/core:framework_headers_lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
@ -1190,6 +1178,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":_dtypes",
|
":_dtypes",
|
||||||
|
":_pywrap_bfloat16",
|
||||||
":pywrap_tensorflow",
|
":pywrap_tensorflow",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
],
|
],
|
||||||
@ -5518,6 +5507,7 @@ tf_py_wrap_cc(
|
|||||||
"//conditions:default": None,
|
"//conditions:default": None,
|
||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
|
":bfloat16_lib",
|
||||||
":cost_analyzer_lib",
|
":cost_analyzer_lib",
|
||||||
":model_analyzer_lib",
|
":model_analyzer_lib",
|
||||||
":cpp_python_util",
|
":cpp_python_util",
|
||||||
@ -5587,8 +5577,7 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
|||||||
":numpy_lib", # checkpoint_reader
|
":numpy_lib", # checkpoint_reader
|
||||||
":safe_ptr", # checkpoint_reader
|
":safe_ptr", # checkpoint_reader
|
||||||
":python_op_gen", # python_op_gen
|
":python_op_gen", # python_op_gen
|
||||||
":bfloat16_lib", # _dtypes
|
":bfloat16_lib", # bfloat16
|
||||||
":ndarray_tensor_types", # _dtypes
|
|
||||||
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
|
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
|
||||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||||
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
|
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
|
||||||
@ -6215,6 +6204,7 @@ cuda_py_test(
|
|||||||
":client_testlib",
|
":client_testlib",
|
||||||
":constant_op",
|
":constant_op",
|
||||||
":dtypes",
|
":dtypes",
|
||||||
|
":framework_for_generated_wrappers",
|
||||||
":framework_ops",
|
":framework_ops",
|
||||||
":training",
|
":training",
|
||||||
":variable_scope",
|
":variable_scope",
|
||||||
|
@ -156,6 +156,7 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":convert",
|
":convert",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -44,7 +44,6 @@ cc_library(
|
|||||||
"//tensorflow/python:cpp_python_util",
|
"//tensorflow/python:cpp_python_util",
|
||||||
"//tensorflow/python:ndarray_tensor",
|
"//tensorflow/python:ndarray_tensor",
|
||||||
"//tensorflow/python:ndarray_tensor_bridge",
|
"//tensorflow/python:ndarray_tensor_bridge",
|
||||||
"//tensorflow/python:ndarray_tensor_types",
|
|
||||||
"//tensorflow/python:numpy_lib",
|
"//tensorflow/python:numpy_lib",
|
||||||
"//tensorflow/python:py_seq_tensor",
|
"//tensorflow/python:py_seq_tensor",
|
||||||
"//tensorflow/python:safe_ptr",
|
"//tensorflow/python:safe_ptr",
|
||||||
|
@ -31,7 +31,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
@ -289,15 +288,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
|||||||
if (PyArray_Check(value)) {
|
if (PyArray_Check(value)) {
|
||||||
int desired_np_dtype = -1;
|
int desired_np_dtype = -1;
|
||||||
if (dtype != tensorflow::DT_INVALID) {
|
if (dtype != tensorflow::DT_INVALID) {
|
||||||
PyArray_Descr* descr = nullptr;
|
if (!tensorflow::TF_DataType_to_PyArray_TYPE(
|
||||||
if (!tensorflow::DataTypeToPyArray_Descr(dtype, &descr).ok()) {
|
static_cast<TF_DataType>(dtype), &desired_np_dtype)
|
||||||
|
.ok()) {
|
||||||
PyErr_SetString(
|
PyErr_SetString(
|
||||||
PyExc_TypeError,
|
PyExc_TypeError,
|
||||||
tensorflow::strings::StrCat("Invalid dtype argument value ", dtype)
|
tensorflow::strings::StrCat("Invalid dtype argument value ", dtype)
|
||||||
.c_str());
|
.c_str());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
desired_np_dtype = descr->type_num;
|
|
||||||
}
|
}
|
||||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
||||||
int current_np_dtype = PyArray_TYPE(array);
|
int current_np_dtype = PyArray_TYPE(array);
|
||||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
|||||||
#include "include/pybind11/pybind11.h"
|
#include "include/pybind11/pybind11.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -61,18 +60,6 @@ inline bool DataTypeIsNumPyCompatible(DataType dt) {
|
|||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
PYBIND11_MODULE(_dtypes, m) {
|
PYBIND11_MODULE(_dtypes, m) {
|
||||||
tensorflow::MaybeRegisterCustomNumPyTypes();
|
|
||||||
|
|
||||||
m.attr("np_bfloat16") =
|
|
||||||
reinterpret_cast<PyObject*>(tensorflow::BFLOAT16_DESCR);
|
|
||||||
m.attr("np_qint8") = reinterpret_cast<PyObject*>(tensorflow::QINT8_DESCR);
|
|
||||||
m.attr("np_qint16") = reinterpret_cast<PyObject*>(tensorflow::QINT16_DESCR);
|
|
||||||
m.attr("np_qint32") = reinterpret_cast<PyObject*>(tensorflow::QINT32_DESCR);
|
|
||||||
m.attr("np_quint8") = reinterpret_cast<PyObject*>(tensorflow::QUINT8_DESCR);
|
|
||||||
m.attr("np_quint16") = reinterpret_cast<PyObject*>(tensorflow::QUINT16_DESCR);
|
|
||||||
m.attr("np_resource") =
|
|
||||||
reinterpret_cast<PyObject*>(tensorflow::RESOURCE_DESCR);
|
|
||||||
|
|
||||||
py::class_<tensorflow::DataType>(m, "DType")
|
py::class_<tensorflow::DataType>(m, "DType")
|
||||||
.def(py::init([](py::object obj) {
|
.def(py::init([](py::object obj) {
|
||||||
auto id = static_cast<int>(py::int_(obj));
|
auto id = static_cast<int>(py::int_(obj));
|
||||||
|
@ -20,16 +20,17 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import builtins
|
from six.moves import builtins
|
||||||
|
|
||||||
# TODO(b/143110113): This import has to come first. This is a temporary
|
|
||||||
# workaround which fixes repeated proto registration on macOS.
|
|
||||||
# pylint: disable=g-bad-import-order, unused-import
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
# pylint: enable=g-bad-import-order, unused-import
|
|
||||||
|
|
||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
|
# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
|
||||||
|
# protobuf errors where a file is defined twice on MacOS.
|
||||||
|
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||||
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
from tensorflow.python import _pywrap_bfloat16
|
||||||
from tensorflow.python import _dtypes
|
from tensorflow.python import _dtypes
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=slots-on-old-class
|
# pylint: disable=slots-on-old-class
|
||||||
@tf_export("dtypes.DType", "DType")
|
@tf_export("dtypes.DType", "DType")
|
||||||
@ -424,18 +425,20 @@ _STRING_TO_TF["double_ref"] = float64_ref
|
|||||||
|
|
||||||
# Numpy representation for quantized dtypes.
|
# Numpy representation for quantized dtypes.
|
||||||
#
|
#
|
||||||
_np_qint8 = _dtypes.np_qint8
|
# These are magic strings that are used in the swig wrapper to identify
|
||||||
_np_qint16 = _dtypes.np_qint16
|
# quantized types.
|
||||||
_np_qint32 = _dtypes.np_qint32
|
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
|
||||||
_np_quint8 = _dtypes.np_quint8
|
# hard-coding of names.
|
||||||
_np_quint16 = _dtypes.np_quint16
|
_np_qint8 = np.dtype([("qint8", np.int8)])
|
||||||
|
_np_quint8 = np.dtype([("quint8", np.uint8)])
|
||||||
|
_np_qint16 = np.dtype([("qint16", np.int16)])
|
||||||
|
_np_quint16 = np.dtype([("quint16", np.uint16)])
|
||||||
|
_np_qint32 = np.dtype([("qint32", np.int32)])
|
||||||
|
|
||||||
# Technically, _np_bfloat does not have to be a Python class, but existing
|
# _np_bfloat16 is defined by a module import.
|
||||||
# code expects it to.
|
|
||||||
_np_bfloat16 = _dtypes.np_bfloat16.type
|
|
||||||
|
|
||||||
# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
|
# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
|
||||||
np_resource = _dtypes.np_resource
|
np_resource = np.dtype([("resource", np.ubyte)])
|
||||||
|
|
||||||
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
||||||
_NP_TO_TF = {
|
_NP_TO_TF = {
|
||||||
|
@ -321,7 +321,7 @@ class PyFuncTest(PyFuncTestBase):
|
|||||||
y, = script_ops.py_func(bad, [], [dtypes.float32])
|
y, = script_ops.py_func(bad, [], [dtypes.float32])
|
||||||
|
|
||||||
with self.assertRaisesRegexp(errors.InternalError,
|
with self.assertRaisesRegexp(errors.InternalError,
|
||||||
"Unsupported NumPy struct data type"):
|
"Unsupported numpy data type"):
|
||||||
self.evaluate(y)
|
self.evaluate(y)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@ -21,19 +21,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct PyDecrefDeleter {
|
|
||||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
|
||||||
};
|
|
||||||
|
|
||||||
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
|
|
||||||
Safe_PyObjectPtr make_safe(PyObject* object) {
|
|
||||||
return Safe_PyObjectPtr(object);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Workarounds for Python 2 vs 3 API differences.
|
// Workarounds for Python 2 vs 3 API differences.
|
||||||
#if PY_MAJOR_VERSION < 3
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
|
||||||
|
@ -24,10 +24,12 @@ import math
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# pylint: disable=unused-import,g-bad-import-order
|
# pylint: disable=unused-import,g-bad-import-order
|
||||||
|
from tensorflow.python import _pywrap_bfloat16
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
bfloat16 = dtypes._np_bfloat16 # pylint: disable=protected-access
|
|
||||||
|
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||||
|
|
||||||
|
|
||||||
class Bfloat16Test(test.TestCase):
|
class Bfloat16Test(test.TestCase):
|
||||||
|
24
tensorflow/python/lib/core/bfloat16_wrapper.cc
Normal file
24
tensorflow/python/lib/core/bfloat16_wrapper.cc
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_bfloat16, m) {
|
||||||
|
tensorflow::RegisterNumpyBfloat16();
|
||||||
|
|
||||||
|
m.def("TF_bfloat16_type",
|
||||||
|
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
|
||||||
|
}
|
@ -21,13 +21,171 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
char const* numpy_type_name(int numpy_type) {
|
||||||
|
switch (numpy_type) {
|
||||||
|
#define TYPE_CASE(s) \
|
||||||
|
case s: \
|
||||||
|
return #s
|
||||||
|
|
||||||
|
TYPE_CASE(NPY_BOOL);
|
||||||
|
TYPE_CASE(NPY_BYTE);
|
||||||
|
TYPE_CASE(NPY_UBYTE);
|
||||||
|
TYPE_CASE(NPY_SHORT);
|
||||||
|
TYPE_CASE(NPY_USHORT);
|
||||||
|
TYPE_CASE(NPY_INT);
|
||||||
|
TYPE_CASE(NPY_UINT);
|
||||||
|
TYPE_CASE(NPY_LONG);
|
||||||
|
TYPE_CASE(NPY_ULONG);
|
||||||
|
TYPE_CASE(NPY_LONGLONG);
|
||||||
|
TYPE_CASE(NPY_ULONGLONG);
|
||||||
|
TYPE_CASE(NPY_FLOAT);
|
||||||
|
TYPE_CASE(NPY_DOUBLE);
|
||||||
|
TYPE_CASE(NPY_LONGDOUBLE);
|
||||||
|
TYPE_CASE(NPY_CFLOAT);
|
||||||
|
TYPE_CASE(NPY_CDOUBLE);
|
||||||
|
TYPE_CASE(NPY_CLONGDOUBLE);
|
||||||
|
TYPE_CASE(NPY_OBJECT);
|
||||||
|
TYPE_CASE(NPY_STRING);
|
||||||
|
TYPE_CASE(NPY_UNICODE);
|
||||||
|
TYPE_CASE(NPY_VOID);
|
||||||
|
TYPE_CASE(NPY_DATETIME);
|
||||||
|
TYPE_CASE(NPY_TIMEDELTA);
|
||||||
|
TYPE_CASE(NPY_HALF);
|
||||||
|
TYPE_CASE(NPY_NTYPES);
|
||||||
|
TYPE_CASE(NPY_NOTYPE);
|
||||||
|
TYPE_CASE(NPY_CHAR);
|
||||||
|
TYPE_CASE(NPY_USERDEF);
|
||||||
|
default:
|
||||||
|
return "not a numpy type";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
|
||||||
|
TF_DataType* out_tf_datatype) {
|
||||||
|
PyObject* key;
|
||||||
|
PyObject* value;
|
||||||
|
Py_ssize_t pos = 0;
|
||||||
|
if (PyDict_Next(descr->fields, &pos, &key, &value)) {
|
||||||
|
// In Python 3, the keys of numpy custom struct types are unicode, unlike
|
||||||
|
// Python 2, where the keys are bytes.
|
||||||
|
const char* key_string =
|
||||||
|
PyBytes_Check(key) ? PyBytes_AsString(key)
|
||||||
|
: PyBytes_AsString(PyUnicode_AsASCIIString(key));
|
||||||
|
if (!key_string) {
|
||||||
|
return errors::Internal("Corrupt numpy type descriptor");
|
||||||
|
}
|
||||||
|
tensorflow::string key = key_string;
|
||||||
|
// The typenames here should match the field names in the custom struct
|
||||||
|
// types constructed in test_util.py.
|
||||||
|
// TODO(mrry,keveman): Investigate Numpy type registration to replace this
|
||||||
|
// hard-coding of names.
|
||||||
|
if (key == "quint8") {
|
||||||
|
*out_tf_datatype = TF_QUINT8;
|
||||||
|
} else if (key == "qint8") {
|
||||||
|
*out_tf_datatype = TF_QINT8;
|
||||||
|
} else if (key == "qint16") {
|
||||||
|
*out_tf_datatype = TF_QINT16;
|
||||||
|
} else if (key == "quint16") {
|
||||||
|
*out_tf_datatype = TF_QUINT16;
|
||||||
|
} else if (key == "qint32") {
|
||||||
|
*out_tf_datatype = TF_QINT32;
|
||||||
|
} else if (key == "resource") {
|
||||||
|
*out_tf_datatype = TF_RESOURCE;
|
||||||
|
} else {
|
||||||
|
return errors::Internal("Unsupported numpy data type");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return errors::Internal("Unsupported numpy data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
|
||||||
|
TF_DataType* out_tf_datatype) {
|
||||||
|
int pyarray_type = PyArray_TYPE(array);
|
||||||
|
PyArray_Descr* descr = PyArray_DESCR(array);
|
||||||
|
switch (pyarray_type) {
|
||||||
|
case NPY_FLOAT16:
|
||||||
|
*out_tf_datatype = TF_HALF;
|
||||||
|
break;
|
||||||
|
case NPY_FLOAT32:
|
||||||
|
*out_tf_datatype = TF_FLOAT;
|
||||||
|
break;
|
||||||
|
case NPY_FLOAT64:
|
||||||
|
*out_tf_datatype = TF_DOUBLE;
|
||||||
|
break;
|
||||||
|
case NPY_INT32:
|
||||||
|
*out_tf_datatype = TF_INT32;
|
||||||
|
break;
|
||||||
|
case NPY_UINT8:
|
||||||
|
*out_tf_datatype = TF_UINT8;
|
||||||
|
break;
|
||||||
|
case NPY_UINT16:
|
||||||
|
*out_tf_datatype = TF_UINT16;
|
||||||
|
break;
|
||||||
|
case NPY_UINT32:
|
||||||
|
*out_tf_datatype = TF_UINT32;
|
||||||
|
break;
|
||||||
|
case NPY_UINT64:
|
||||||
|
*out_tf_datatype = TF_UINT64;
|
||||||
|
break;
|
||||||
|
case NPY_INT8:
|
||||||
|
*out_tf_datatype = TF_INT8;
|
||||||
|
break;
|
||||||
|
case NPY_INT16:
|
||||||
|
*out_tf_datatype = TF_INT16;
|
||||||
|
break;
|
||||||
|
case NPY_INT64:
|
||||||
|
*out_tf_datatype = TF_INT64;
|
||||||
|
break;
|
||||||
|
case NPY_BOOL:
|
||||||
|
*out_tf_datatype = TF_BOOL;
|
||||||
|
break;
|
||||||
|
case NPY_COMPLEX64:
|
||||||
|
*out_tf_datatype = TF_COMPLEX64;
|
||||||
|
break;
|
||||||
|
case NPY_COMPLEX128:
|
||||||
|
*out_tf_datatype = TF_COMPLEX128;
|
||||||
|
break;
|
||||||
|
case NPY_OBJECT:
|
||||||
|
case NPY_STRING:
|
||||||
|
case NPY_UNICODE:
|
||||||
|
*out_tf_datatype = TF_STRING;
|
||||||
|
break;
|
||||||
|
case NPY_VOID:
|
||||||
|
// Quantized types are currently represented as custom struct types.
|
||||||
|
// PyArray_TYPE returns NPY_VOID for structs, and we should look into
|
||||||
|
// descr to derive the actual type.
|
||||||
|
// Direct feeds of certain types of ResourceHandles are represented as a
|
||||||
|
// custom struct type.
|
||||||
|
return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
|
||||||
|
default:
|
||||||
|
if (pyarray_type == Bfloat16NumpyType()) {
|
||||||
|
*out_tf_datatype = TF_BFLOAT16;
|
||||||
|
break;
|
||||||
|
} else if (pyarray_type == NPY_ULONGLONG) {
|
||||||
|
// NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
|
||||||
|
// might be different on certain platforms.
|
||||||
|
*out_tf_datatype = TF_UINT64;
|
||||||
|
break;
|
||||||
|
} else if (pyarray_type == NPY_LONGLONG) {
|
||||||
|
// NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
|
||||||
|
// might be different on certain platforms.
|
||||||
|
*out_tf_datatype = TF_INT64;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return errors::Internal("Unsupported numpy type: ",
|
||||||
|
numpy_type_name(pyarray_type));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
|
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
|
||||||
PyObject** ptr_owner) {
|
PyObject** ptr_owner) {
|
||||||
*ptr_owner = nullptr;
|
*ptr_owner = nullptr;
|
||||||
@ -186,6 +344,38 @@ Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine the type description (PyArray_Descr) of a numpy ndarray to be
|
||||||
|
// created to represent an output Tensor.
|
||||||
|
Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
|
||||||
|
PyArray_Descr** descr) {
|
||||||
|
if (TF_TensorType(tensor) == TF_RESOURCE) {
|
||||||
|
PyObject* field = PyTuple_New(3);
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
|
||||||
|
#else
|
||||||
|
PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
|
||||||
|
#endif
|
||||||
|
PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
|
||||||
|
PyTuple_SetItem(field, 2, PyLong_FromLong(1));
|
||||||
|
PyObject* fields = PyList_New(1);
|
||||||
|
PyList_SetItem(fields, 0, field);
|
||||||
|
int convert_result = PyArray_DescrConverter(fields, descr);
|
||||||
|
Py_CLEAR(field);
|
||||||
|
Py_CLEAR(fields);
|
||||||
|
if (convert_result != 1) {
|
||||||
|
return errors::Internal("Failed to create numpy array description for ",
|
||||||
|
"TF_RESOURCE-type tensor");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int type_num = -1;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
|
||||||
|
*descr = PyArray_DescrFromType(type_num);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
inline void FastMemcpy(void* dst, const void* src, size_t size) {
|
inline void FastMemcpy(void* dst, const void* src, size_t size) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
switch (size) {
|
switch (size) {
|
||||||
@ -271,8 +461,7 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
|||||||
|
|
||||||
// Copy the TF_TensorData into a newly-created ndarray and return it.
|
// Copy the TF_TensorData into a newly-created ndarray and return it.
|
||||||
PyArray_Descr* descr = nullptr;
|
PyArray_Descr* descr = nullptr;
|
||||||
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(
|
TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
|
||||||
static_cast<DataType>(TF_TensorType(tensor.get())), &descr));
|
|
||||||
Safe_PyObjectPtr safe_out_array =
|
Safe_PyObjectPtr safe_out_array =
|
||||||
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
|
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
|
||||||
if (!safe_out_array) {
|
if (!safe_out_array) {
|
||||||
@ -310,11 +499,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
|||||||
|
|
||||||
// Convert numpy dtype to TensorFlow dtype.
|
// Convert numpy dtype to TensorFlow dtype.
|
||||||
TF_DataType dtype = TF_FLOAT;
|
TF_DataType dtype = TF_FLOAT;
|
||||||
{
|
TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
|
||||||
DataType tmp;
|
|
||||||
TF_RETURN_IF_ERROR(PyArray_DescrToDataType(PyArray_DESCR(array), &tmp));
|
|
||||||
dtype = static_cast<TF_DataType>(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::int64 nelems = 1;
|
tensorflow::int64 nelems = 1;
|
||||||
gtl::InlinedVector<int64_t, 4> dims;
|
gtl::InlinedVector<int64_t, 4> dims;
|
||||||
|
@ -13,19 +13,16 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
// Must be included first.
|
||||||
|
#include "tensorflow/python/lib/core/numpy.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Must be included first.
|
|
||||||
// clang-format: off
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
// clang-format: on
|
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||||
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -110,6 +107,85 @@ PyTypeObject TensorReleaserType = {
|
|||||||
nullptr, /* tp_richcompare */
|
nullptr, /* tp_richcompare */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
|
||||||
|
int* out_pyarray_type) {
|
||||||
|
switch (tf_datatype) {
|
||||||
|
case TF_HALF:
|
||||||
|
*out_pyarray_type = NPY_FLOAT16;
|
||||||
|
break;
|
||||||
|
case TF_FLOAT:
|
||||||
|
*out_pyarray_type = NPY_FLOAT32;
|
||||||
|
break;
|
||||||
|
case TF_DOUBLE:
|
||||||
|
*out_pyarray_type = NPY_FLOAT64;
|
||||||
|
break;
|
||||||
|
case TF_INT32:
|
||||||
|
*out_pyarray_type = NPY_INT32;
|
||||||
|
break;
|
||||||
|
case TF_UINT32:
|
||||||
|
*out_pyarray_type = NPY_UINT32;
|
||||||
|
break;
|
||||||
|
case TF_UINT8:
|
||||||
|
*out_pyarray_type = NPY_UINT8;
|
||||||
|
break;
|
||||||
|
case TF_UINT16:
|
||||||
|
*out_pyarray_type = NPY_UINT16;
|
||||||
|
break;
|
||||||
|
case TF_INT8:
|
||||||
|
*out_pyarray_type = NPY_INT8;
|
||||||
|
break;
|
||||||
|
case TF_INT16:
|
||||||
|
*out_pyarray_type = NPY_INT16;
|
||||||
|
break;
|
||||||
|
case TF_INT64:
|
||||||
|
*out_pyarray_type = NPY_INT64;
|
||||||
|
break;
|
||||||
|
case TF_UINT64:
|
||||||
|
*out_pyarray_type = NPY_UINT64;
|
||||||
|
break;
|
||||||
|
case TF_BOOL:
|
||||||
|
*out_pyarray_type = NPY_BOOL;
|
||||||
|
break;
|
||||||
|
case TF_COMPLEX64:
|
||||||
|
*out_pyarray_type = NPY_COMPLEX64;
|
||||||
|
break;
|
||||||
|
case TF_COMPLEX128:
|
||||||
|
*out_pyarray_type = NPY_COMPLEX128;
|
||||||
|
break;
|
||||||
|
case TF_STRING:
|
||||||
|
*out_pyarray_type = NPY_OBJECT;
|
||||||
|
break;
|
||||||
|
case TF_RESOURCE:
|
||||||
|
*out_pyarray_type = NPY_VOID;
|
||||||
|
break;
|
||||||
|
// TODO(keveman): These should be changed to NPY_VOID, and the type used for
|
||||||
|
// the resulting numpy array should be the custom struct types that we
|
||||||
|
// expect for quantized types.
|
||||||
|
case TF_QINT8:
|
||||||
|
*out_pyarray_type = NPY_INT8;
|
||||||
|
break;
|
||||||
|
case TF_QUINT8:
|
||||||
|
*out_pyarray_type = NPY_UINT8;
|
||||||
|
break;
|
||||||
|
case TF_QINT16:
|
||||||
|
*out_pyarray_type = NPY_INT16;
|
||||||
|
break;
|
||||||
|
case TF_QUINT16:
|
||||||
|
*out_pyarray_type = NPY_UINT16;
|
||||||
|
break;
|
||||||
|
case TF_QINT32:
|
||||||
|
*out_pyarray_type = NPY_INT32;
|
||||||
|
break;
|
||||||
|
case TF_BFLOAT16:
|
||||||
|
*out_pyarray_type = Bfloat16NumpyType();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return errors::Internal("Tensorflow type ", tf_datatype,
|
||||||
|
" not convertible to numpy dtype.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
||||||
std::function<void()> destructor, PyObject** result) {
|
std::function<void()> destructor, PyObject** result) {
|
||||||
if (dtype == DT_STRING || dtype == DT_RESOURCE) {
|
if (dtype == DT_STRING || dtype == DT_RESOURCE) {
|
||||||
@ -117,11 +193,15 @@ Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
|||||||
"Cannot convert string or resource Tensors.");
|
"Cannot convert string or resource Tensors.");
|
||||||
}
|
}
|
||||||
|
|
||||||
PyArray_Descr* descr = nullptr;
|
int type_num = -1;
|
||||||
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(dtype, &descr));
|
Status s =
|
||||||
|
TF_DataType_to_PyArray_TYPE(static_cast<TF_DataType>(dtype), &type_num);
|
||||||
|
if (!s.ok()) {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
auto* np_array = reinterpret_cast<PyArrayObject*>(
|
auto* np_array = reinterpret_cast<PyArrayObject*>(
|
||||||
PyArray_SimpleNewFromData(dim_size, dims, descr->type_num, data));
|
PyArray_SimpleNewFromData(dim_size, dims, type_num, data));
|
||||||
CHECK_NE(np_array, nullptr);
|
|
||||||
PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);
|
PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);
|
||||||
if (PyType_Ready(&TensorReleaserType) == -1) {
|
if (PyType_Ready(&TensorReleaserType) == -1) {
|
||||||
return errors::Unknown("Python type initialization failed.");
|
return errors::Unknown("Python type initialization failed.");
|
||||||
|
@ -42,6 +42,10 @@ void ClearDecrefCache();
|
|||||||
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
||||||
std::function<void()> destructor, PyObject** result);
|
std::function<void()> destructor, PyObject** result);
|
||||||
|
|
||||||
|
// Converts TF_DataType to the corresponding numpy type.
|
||||||
|
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
|
||||||
|
int* out_pyarray_type);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_
|
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_
|
||||||
|
@ -1,287 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
|
||||||
|
|
||||||
#include <Python.h>
|
|
||||||
|
|
||||||
// Must be included first.
|
|
||||||
// clang-format: off
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
// clang-format: on
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
|
||||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
PyArray_Descr* BFLOAT16_DESCR = nullptr;
|
|
||||||
PyArray_Descr* QINT8_DESCR = nullptr;
|
|
||||||
PyArray_Descr* QINT16_DESCR = nullptr;
|
|
||||||
PyArray_Descr* QINT32_DESCR = nullptr;
|
|
||||||
PyArray_Descr* QUINT8_DESCR = nullptr;
|
|
||||||
PyArray_Descr* QUINT16_DESCR = nullptr;
|
|
||||||
PyArray_Descr* RESOURCE_DESCR = nullptr;
|
|
||||||
|
|
||||||
// Define a struct array data type `[(tag, type_num)]`.
|
|
||||||
PyArray_Descr* DefineStructTypeAlias(const char* tag, int type_num) {
|
|
||||||
#if PY_MAJOR_VERSION < 3
|
|
||||||
auto* py_tag = PyBytes_FromString(tag);
|
|
||||||
#else
|
|
||||||
auto* py_tag = PyUnicode_FromString(tag);
|
|
||||||
#endif
|
|
||||||
auto* descr = PyArray_DescrFromType(type_num);
|
|
||||||
auto* py_tag_and_descr = PyTuple_Pack(2, py_tag, descr);
|
|
||||||
auto* obj = PyList_New(1);
|
|
||||||
PyList_SetItem(obj, 0, py_tag_and_descr);
|
|
||||||
PyArray_Descr* alias_descr = nullptr;
|
|
||||||
// TODO(slebedev): Switch to PyArray_DescrNewFromType because struct
|
|
||||||
// array dtypes could not be used for scalars. Note that this will
|
|
||||||
// require registering type conversions and UFunc specializations.
|
|
||||||
// See b/144230631.
|
|
||||||
CHECK_EQ(PyArray_DescrConverter(obj, &alias_descr), NPY_SUCCEED);
|
|
||||||
Py_DECREF(obj);
|
|
||||||
Py_DECREF(py_tag_and_descr);
|
|
||||||
Py_DECREF(py_tag);
|
|
||||||
Py_DECREF(descr);
|
|
||||||
CHECK_NE(alias_descr, nullptr);
|
|
||||||
return alias_descr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MaybeRegisterCustomNumPyTypes() {
|
|
||||||
static bool registered = false;
|
|
||||||
if (registered) return;
|
|
||||||
ImportNumpy(); // Ensure NumPy is loaded.
|
|
||||||
// Make sure the tags are consistent with DataTypeToPyArray_Descr.
|
|
||||||
QINT8_DESCR = DefineStructTypeAlias("qint8", NPY_INT8);
|
|
||||||
QINT16_DESCR = DefineStructTypeAlias("qint16", NPY_INT16);
|
|
||||||
QINT32_DESCR = DefineStructTypeAlias("qint32", NPY_INT32);
|
|
||||||
QUINT8_DESCR = DefineStructTypeAlias("quint8", NPY_UINT8);
|
|
||||||
QUINT16_DESCR = DefineStructTypeAlias("quint16", NPY_UINT16);
|
|
||||||
RESOURCE_DESCR = DefineStructTypeAlias("resource", NPY_UBYTE);
|
|
||||||
RegisterNumpyBfloat16();
|
|
||||||
BFLOAT16_DESCR = PyArray_DescrFromType(Bfloat16NumpyType());
|
|
||||||
registered = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char* PyArray_DescrReprAsString(PyArray_Descr* descr) {
|
|
||||||
auto* descr_repr = PyObject_Repr(reinterpret_cast<PyObject*>(descr));
|
|
||||||
const char* result;
|
|
||||||
#if PY_MAJOR_VERSION < 3
|
|
||||||
result = PyBytes_AsString(descr_repr);
|
|
||||||
#else
|
|
||||||
auto* tmp = PyUnicode_AsASCIIString(descr_repr);
|
|
||||||
result = PyBytes_AsString(tmp);
|
|
||||||
Py_DECREF(tmp);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
Py_DECREF(descr_repr);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr) {
|
|
||||||
switch (dt) {
|
|
||||||
case DT_HALF:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT16);
|
|
||||||
break;
|
|
||||||
case DT_FLOAT:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT32);
|
|
||||||
break;
|
|
||||||
case DT_DOUBLE:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT64);
|
|
||||||
break;
|
|
||||||
case DT_INT32:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT32);
|
|
||||||
break;
|
|
||||||
case DT_UINT32:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT32);
|
|
||||||
break;
|
|
||||||
case DT_UINT8:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT8);
|
|
||||||
break;
|
|
||||||
case DT_UINT16:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT16);
|
|
||||||
break;
|
|
||||||
case DT_INT8:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT8);
|
|
||||||
break;
|
|
||||||
case DT_INT16:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT16);
|
|
||||||
break;
|
|
||||||
case DT_INT64:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT64);
|
|
||||||
break;
|
|
||||||
case DT_UINT64:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT64);
|
|
||||||
break;
|
|
||||||
case DT_BOOL:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_BOOL);
|
|
||||||
break;
|
|
||||||
case DT_COMPLEX64:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_COMPLEX64);
|
|
||||||
break;
|
|
||||||
case DT_COMPLEX128:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_COMPLEX128);
|
|
||||||
break;
|
|
||||||
case DT_STRING:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_OBJECT);
|
|
||||||
break;
|
|
||||||
case DT_QINT8:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT8);
|
|
||||||
break;
|
|
||||||
case DT_QINT16:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT16);
|
|
||||||
break;
|
|
||||||
case DT_QINT32:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_INT32);
|
|
||||||
break;
|
|
||||||
case DT_QUINT8:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT8);
|
|
||||||
break;
|
|
||||||
case DT_QUINT16:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UINT16);
|
|
||||||
break;
|
|
||||||
case DT_RESOURCE:
|
|
||||||
*out_descr = PyArray_DescrFromType(NPY_UBYTE);
|
|
||||||
break;
|
|
||||||
case DT_BFLOAT16:
|
|
||||||
Py_INCREF(BFLOAT16_DESCR);
|
|
||||||
*out_descr = BFLOAT16_DESCR;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return errors::Internal("TensorFlow data type ", DataType_Name(dt),
|
|
||||||
" cannot be converted to a NumPy data type.");
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// NumPy defines fixed-width aliases for platform integer types. However,
|
|
||||||
// some types do not have a fixed-width alias. Specifically
|
|
||||||
//
|
|
||||||
// * on a LLP64 system NPY_INT32 == NPY_LONG therefore NPY_INT is not aliased;
|
|
||||||
// * on a LP64 system NPY_INT64 == NPY_LONG and NPY_LONGLONG is not aliased.
|
|
||||||
//
|
|
||||||
int MaybeResolveNumPyPlatformType(int type_num) {
|
|
||||||
switch (type_num) {
|
|
||||||
#if NPY_BITS_OF_INT == 32 && NPY_BITS_OF_LONGLONG == 32
|
|
||||||
case NPY_INT:
|
|
||||||
return NPY_INT32;
|
|
||||||
case NPY_UINT:
|
|
||||||
return NPY_UINT32;
|
|
||||||
#endif
|
|
||||||
#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONGLONG == 64
|
|
||||||
case NPY_LONGLONG:
|
|
||||||
return NPY_INT64;
|
|
||||||
case NPY_ULONGLONG:
|
|
||||||
return NPY_UINT64;
|
|
||||||
#endif
|
|
||||||
default:
|
|
||||||
return type_num;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt) {
|
|
||||||
const int type_num = MaybeResolveNumPyPlatformType(descr->type_num);
|
|
||||||
switch (type_num) {
|
|
||||||
case NPY_FLOAT16:
|
|
||||||
*out_dt = DT_HALF;
|
|
||||||
break;
|
|
||||||
case NPY_FLOAT32:
|
|
||||||
*out_dt = DT_FLOAT;
|
|
||||||
break;
|
|
||||||
case NPY_FLOAT64:
|
|
||||||
*out_dt = DT_DOUBLE;
|
|
||||||
break;
|
|
||||||
case NPY_INT8:
|
|
||||||
*out_dt = DT_INT8;
|
|
||||||
break;
|
|
||||||
case NPY_INT16:
|
|
||||||
*out_dt = DT_INT16;
|
|
||||||
break;
|
|
||||||
case NPY_INT32:
|
|
||||||
*out_dt = DT_INT32;
|
|
||||||
break;
|
|
||||||
case NPY_INT64:
|
|
||||||
*out_dt = DT_INT64;
|
|
||||||
break;
|
|
||||||
case NPY_UINT8:
|
|
||||||
*out_dt = DT_UINT8;
|
|
||||||
break;
|
|
||||||
case NPY_UINT16:
|
|
||||||
*out_dt = DT_UINT16;
|
|
||||||
break;
|
|
||||||
case NPY_UINT32:
|
|
||||||
*out_dt = DT_UINT32;
|
|
||||||
break;
|
|
||||||
case NPY_UINT64:
|
|
||||||
*out_dt = DT_UINT64;
|
|
||||||
break;
|
|
||||||
case NPY_BOOL:
|
|
||||||
*out_dt = DT_BOOL;
|
|
||||||
break;
|
|
||||||
case NPY_COMPLEX64:
|
|
||||||
*out_dt = DT_COMPLEX64;
|
|
||||||
break;
|
|
||||||
case NPY_COMPLEX128:
|
|
||||||
*out_dt = DT_COMPLEX128;
|
|
||||||
break;
|
|
||||||
case NPY_OBJECT:
|
|
||||||
case NPY_STRING:
|
|
||||||
case NPY_UNICODE:
|
|
||||||
*out_dt = DT_STRING;
|
|
||||||
break;
|
|
||||||
case NPY_VOID: {
|
|
||||||
if (descr == QINT8_DESCR) {
|
|
||||||
*out_dt = DT_QINT8;
|
|
||||||
break;
|
|
||||||
} else if (descr == QINT16_DESCR) {
|
|
||||||
*out_dt = DT_QINT16;
|
|
||||||
break;
|
|
||||||
} else if (descr == QINT32_DESCR) {
|
|
||||||
*out_dt = DT_QINT32;
|
|
||||||
break;
|
|
||||||
} else if (descr == QUINT8_DESCR) {
|
|
||||||
*out_dt = DT_QUINT8;
|
|
||||||
break;
|
|
||||||
} else if (descr == QUINT16_DESCR) {
|
|
||||||
*out_dt = DT_QUINT16;
|
|
||||||
break;
|
|
||||||
} else if (descr == RESOURCE_DESCR) {
|
|
||||||
*out_dt = DT_RESOURCE;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors::Internal("Unsupported NumPy struct data type: ",
|
|
||||||
PyArray_DescrReprAsString(descr));
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if (type_num == Bfloat16NumpyType()) {
|
|
||||||
*out_dt = DT_BFLOAT16;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors::Internal("Unregistered NumPy data type: ",
|
|
||||||
PyArray_DescrReprAsString(descr));
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
@ -1,65 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
|
||||||
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
|
||||||
|
|
||||||
// Must be included first.
|
|
||||||
// clang-format: off
|
|
||||||
#include "tensorflow/python/lib/core/numpy.h"
|
|
||||||
// clang-format: on
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
extern PyArray_Descr* QINT8_DESCR;
|
|
||||||
extern PyArray_Descr* QINT16_DESCR;
|
|
||||||
extern PyArray_Descr* QINT32_DESCR;
|
|
||||||
extern PyArray_Descr* QUINT8_DESCR;
|
|
||||||
extern PyArray_Descr* QUINT16_DESCR;
|
|
||||||
extern PyArray_Descr* RESOURCE_DESCR;
|
|
||||||
extern PyArray_Descr* BFLOAT16_DESCR;
|
|
||||||
|
|
||||||
// Register custom NumPy types.
|
|
||||||
//
|
|
||||||
// This function must be called in order to be able to map TensorFlow
|
|
||||||
// data types which do not have a corresponding standard NumPy data type
|
|
||||||
// (e.g. bfloat16 or qint8).
|
|
||||||
//
|
|
||||||
// TODO(b/144230631): The name is slightly misleading, as the function only
|
|
||||||
// registers bfloat16 and defines structured aliases for other data types
|
|
||||||
// (e.g. qint8).
|
|
||||||
void MaybeRegisterCustomNumPyTypes();
|
|
||||||
|
|
||||||
// Returns a NumPy data type matching a given tensorflow::DataType. If the
|
|
||||||
// function call succeeds, the caller is responsible for DECREF'ing the
|
|
||||||
// resulting PyArray_Descr*.
|
|
||||||
//
|
|
||||||
// NumPy does not support quantized integer types, so TensorFlow defines
|
|
||||||
// structured aliases for them, e.g. tf.qint8 is represented as
|
|
||||||
// np.dtype([("qint8", np.int8)]). However, for historical reasons this
|
|
||||||
// function does not use these aliases, and instead returns the *aliased*
|
|
||||||
// types (np.int8 in the example).
|
|
||||||
// TODO(b/144230631): Return an alias instead of the aliased type.
|
|
||||||
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr);
|
|
||||||
|
|
||||||
// Returns a tensorflow::DataType corresponding to a given NumPy data type.
|
|
||||||
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt);
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
|
@ -175,7 +175,7 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file):
|
|||||||
else:
|
else:
|
||||||
# If not a section header and not an empty line, then it's a symbol
|
# If not a section header and not an empty line, then it's a symbol
|
||||||
# line. e.g. `tensorflow::swig::IsSequence`
|
# line. e.g. `tensorflow::swig::IsSequence`
|
||||||
symbols[curr_lib].append(re.escape(line))
|
symbols[curr_lib].append(line)
|
||||||
|
|
||||||
lib_paths = []
|
lib_paths = []
|
||||||
with open(lib_paths_file, "r") as f:
|
with open(lib_paths_file, "r") as f:
|
||||||
|
@ -43,6 +43,10 @@ tensorflow::tfprof::SerializeToString
|
|||||||
[graph_analyzer_tool] # graph_analyzer
|
[graph_analyzer_tool] # graph_analyzer
|
||||||
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
|
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
|
||||||
|
|
||||||
|
[bfloat16_lib] # bfloat16
|
||||||
|
tensorflow::RegisterNumpyBfloat16
|
||||||
|
tensorflow::Bfloat16PyType
|
||||||
|
|
||||||
[events_writer] # events_writer
|
[events_writer] # events_writer
|
||||||
tensorflow::EventsWriter::Init
|
tensorflow::EventsWriter::Init
|
||||||
tensorflow::EventsWriter::InitWithSuffix
|
tensorflow::EventsWriter::InitWithSuffix
|
||||||
@ -185,13 +189,3 @@ tensorflow::Set_TF_Status_from_Status
|
|||||||
|
|
||||||
[context] # tfe
|
[context] # tfe
|
||||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||||
|
|
||||||
[ndarray_tensor_types] # _dtypes
|
|
||||||
tensorflow::MaybeRegisterCustomNumPyTypes
|
|
||||||
tensorflow::BFLOAT16_DESCR
|
|
||||||
tensorflow::QINT8_DESCR
|
|
||||||
tensorflow::QINT16_DESCR
|
|
||||||
tensorflow::QINT32_DESCR
|
|
||||||
tensorflow::QUINT8_DESCR
|
|
||||||
tensorflow::QUINT16_DESCR
|
|
||||||
tensorflow::RESOURCE_DESCR
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user