Moved custom NumPy data types to :ndarray_tensor_types
Prior to this change qint*, quint* and resource were defined in Python as a single-field struct arrays, e.g. _np_qint8 = np.dtype([("qint8", np.int8)]) Several TensorFlow functions had to special-case struct arrays (which otherwise have type NPY_VOID) and infer the real dtype from struct fields. Having these data types defined and handled in C++ allows to minimize magic on the Python/C++ boundary. Note 1: The defined data types are *not* registered because NumPy requires every registered data type to also register casting functions and UFunc specializations which will require significantly more code. Note 2: Tensor->NumPy conversion for qint*, quint* and resource dtypes does not use the defined types, so tf.constant([42], dtype=tf.qint8).numpy() returns an array of dtype tf.int8. This behavior is unaffected by this change. PiperOrigin-RevId: 286392791 Change-Id: I0bdc55c2002195bca6273c94d3965d6620239985
This commit is contained in:
parent
69111e174c
commit
40dab8918d
@ -394,8 +394,6 @@ cc_library(
|
||||
srcs = ["lib/core/numpy.cc"],
|
||||
hdrs = ["lib/core/numpy.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
@ -407,24 +405,12 @@ cc_library(
|
||||
hdrs = ["lib/core/bfloat16.h"],
|
||||
deps = [
|
||||
":numpy_lib",
|
||||
":safe_ptr",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_bfloat16",
|
||||
srcs = ["lib/core/bfloat16_wrapper.cc"],
|
||||
hdrs = ["lib/core/bfloat16.h"],
|
||||
module_name = "_pywrap_bfloat16",
|
||||
deps = [
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ndarray_tensor_bridge",
|
||||
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
|
||||
@ -435,7 +421,7 @@ cc_library(
|
||||
],
|
||||
),
|
||||
deps = [
|
||||
":bfloat16_lib",
|
||||
":ndarray_tensor_types",
|
||||
":numpy_lib",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:lib",
|
||||
@ -796,6 +782,31 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ndarray_tensor_types",
|
||||
srcs = ["lib/core/ndarray_tensor_types.cc"],
|
||||
hdrs = ["lib/core/ndarray_tensor_types.h"],
|
||||
deps = [
|
||||
":bfloat16_lib",
|
||||
":numpy_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ndarray_tensor_types_headers_lib",
|
||||
hdrs = ["lib/core/ndarray_tensor_types.h"],
|
||||
deps = [
|
||||
":numpy_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ndarray_tensor",
|
||||
srcs = ["lib/core/ndarray_tensor.cc"],
|
||||
@ -804,8 +815,8 @@ cc_library(
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
]),
|
||||
deps = [
|
||||
":bfloat16_lib",
|
||||
":ndarray_tensor_bridge",
|
||||
":ndarray_tensor_types",
|
||||
":numpy_lib",
|
||||
":safe_ptr",
|
||||
"//tensorflow/c:c_api",
|
||||
@ -1165,6 +1176,7 @@ tf_python_pybind_extension(
|
||||
srcs = ["framework/dtypes.cc"],
|
||||
module_name = "_dtypes",
|
||||
deps = [
|
||||
":ndarray_tensor_types_headers_lib",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/eigen3",
|
||||
@ -1178,7 +1190,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_dtypes",
|
||||
":_pywrap_bfloat16",
|
||||
":pywrap_tensorflow",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
@ -5507,7 +5518,6 @@ tf_py_wrap_cc(
|
||||
"//conditions:default": None,
|
||||
}),
|
||||
deps = [
|
||||
":bfloat16_lib",
|
||||
":cost_analyzer_lib",
|
||||
":model_analyzer_lib",
|
||||
":cpp_python_util",
|
||||
@ -5577,7 +5587,8 @@ WIN_LIB_FILES_FOR_EXPORTED_SYMBOLS = [
|
||||
":numpy_lib", # checkpoint_reader
|
||||
":safe_ptr", # checkpoint_reader
|
||||
":python_op_gen", # python_op_gen
|
||||
":bfloat16_lib", # bfloat16
|
||||
":bfloat16_lib", # _dtypes
|
||||
":ndarray_tensor_types", # _dtypes
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
|
||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
|
||||
@ -6204,7 +6215,6 @@ cuda_py_test(
|
||||
":client_testlib",
|
||||
":constant_op",
|
||||
":dtypes",
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_ops",
|
||||
":training",
|
||||
":variable_scope",
|
||||
|
@ -156,7 +156,6 @@ py_test(
|
||||
deps = [
|
||||
":convert",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
@ -44,6 +44,7 @@ cc_library(
|
||||
"//tensorflow/python:cpp_python_util",
|
||||
"//tensorflow/python:ndarray_tensor",
|
||||
"//tensorflow/python:ndarray_tensor_bridge",
|
||||
"//tensorflow/python:ndarray_tensor_types",
|
||||
"//tensorflow/python:numpy_lib",
|
||||
"//tensorflow/python:py_seq_tensor",
|
||||
"//tensorflow/python:safe_ptr",
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/py_seq_tensor.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
@ -288,15 +289,15 @@ TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
|
||||
if (PyArray_Check(value)) {
|
||||
int desired_np_dtype = -1;
|
||||
if (dtype != tensorflow::DT_INVALID) {
|
||||
if (!tensorflow::TF_DataType_to_PyArray_TYPE(
|
||||
static_cast<TF_DataType>(dtype), &desired_np_dtype)
|
||||
.ok()) {
|
||||
PyArray_Descr* descr = nullptr;
|
||||
if (!tensorflow::DataTypeToPyArray_Descr(dtype, &descr).ok()) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat("Invalid dtype argument value ", dtype)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
desired_np_dtype = descr->type_num;
|
||||
}
|
||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
||||
int current_np_dtype = PyArray_TYPE(array);
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -60,6 +61,18 @@ inline bool DataTypeIsNumPyCompatible(DataType dt) {
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_dtypes, m) {
|
||||
tensorflow::MaybeRegisterCustomNumPyTypes();
|
||||
|
||||
m.attr("np_bfloat16") =
|
||||
reinterpret_cast<PyObject*>(tensorflow::BFLOAT16_DESCR);
|
||||
m.attr("np_qint8") = reinterpret_cast<PyObject*>(tensorflow::QINT8_DESCR);
|
||||
m.attr("np_qint16") = reinterpret_cast<PyObject*>(tensorflow::QINT16_DESCR);
|
||||
m.attr("np_qint32") = reinterpret_cast<PyObject*>(tensorflow::QINT32_DESCR);
|
||||
m.attr("np_quint8") = reinterpret_cast<PyObject*>(tensorflow::QUINT8_DESCR);
|
||||
m.attr("np_quint16") = reinterpret_cast<PyObject*>(tensorflow::QUINT16_DESCR);
|
||||
m.attr("np_resource") =
|
||||
reinterpret_cast<PyObject*>(tensorflow::RESOURCE_DESCR);
|
||||
|
||||
py::class_<tensorflow::DataType>(m, "DType")
|
||||
.def(py::init([](py::object obj) {
|
||||
auto id = static_cast<int>(py::int_(obj));
|
||||
|
@ -20,17 +20,16 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from six.moves import builtins
|
||||
|
||||
# TODO(b/143110113): This import has to come first. This is a temporary
|
||||
# workaround which fixes repeated proto registration on macOS.
|
||||
# pylint: disable=g-bad-import-order, unused-import
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
# pylint: enable=g-bad-import-order, unused-import
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
|
||||
# protobuf errors where a file is defined twice on MacOS.
|
||||
# pylint: disable=invalid-import-order,g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
from tensorflow.python import _pywrap_bfloat16
|
||||
from tensorflow.python import _dtypes
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||
|
||||
|
||||
# pylint: disable=slots-on-old-class
|
||||
@tf_export("dtypes.DType", "DType")
|
||||
@ -425,20 +424,18 @@ _STRING_TO_TF["double_ref"] = float64_ref
|
||||
|
||||
# Numpy representation for quantized dtypes.
|
||||
#
|
||||
# These are magic strings that are used in the swig wrapper to identify
|
||||
# quantized types.
|
||||
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
|
||||
# hard-coding of names.
|
||||
_np_qint8 = np.dtype([("qint8", np.int8)])
|
||||
_np_quint8 = np.dtype([("quint8", np.uint8)])
|
||||
_np_qint16 = np.dtype([("qint16", np.int16)])
|
||||
_np_quint16 = np.dtype([("quint16", np.uint16)])
|
||||
_np_qint32 = np.dtype([("qint32", np.int32)])
|
||||
_np_qint8 = _dtypes.np_qint8
|
||||
_np_qint16 = _dtypes.np_qint16
|
||||
_np_qint32 = _dtypes.np_qint32
|
||||
_np_quint8 = _dtypes.np_quint8
|
||||
_np_quint16 = _dtypes.np_quint16
|
||||
|
||||
# _np_bfloat16 is defined by a module import.
|
||||
# Technically, _np_bfloat does not have to be a Python class, but existing
|
||||
# code expects it to.
|
||||
_np_bfloat16 = _dtypes.np_bfloat16.type
|
||||
|
||||
# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
|
||||
np_resource = np.dtype([("resource", np.ubyte)])
|
||||
np_resource = _dtypes.np_resource
|
||||
|
||||
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
|
||||
_NP_TO_TF = {
|
||||
|
@ -321,7 +321,7 @@ class PyFuncTest(PyFuncTestBase):
|
||||
y, = script_ops.py_func(bad, [], [dtypes.float32])
|
||||
|
||||
with self.assertRaisesRegexp(errors.InternalError,
|
||||
"Unsupported numpy data type"):
|
||||
"Unsupported NumPy struct data type"):
|
||||
self.evaluate(y)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
|
@ -21,11 +21,19 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
struct PyDecrefDeleter {
|
||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||
};
|
||||
|
||||
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
|
||||
Safe_PyObjectPtr make_safe(PyObject* object) {
|
||||
return Safe_PyObjectPtr(object);
|
||||
}
|
||||
|
||||
// Workarounds for Python 2 vs 3 API differences.
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
|
||||
|
@ -24,12 +24,10 @@ import math
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=unused-import,g-bad-import-order
|
||||
from tensorflow.python import _pywrap_bfloat16
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
|
||||
bfloat16 = dtypes._np_bfloat16 # pylint: disable=protected-access
|
||||
|
||||
|
||||
class Bfloat16Test(test.TestCase):
|
||||
|
@ -1,24 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
|
||||
PYBIND11_MODULE(_pywrap_bfloat16, m) {
|
||||
tensorflow::RegisterNumpyBfloat16();
|
||||
|
||||
m.def("TF_bfloat16_type",
|
||||
[] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
|
||||
}
|
@ -21,171 +21,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
char const* numpy_type_name(int numpy_type) {
|
||||
switch (numpy_type) {
|
||||
#define TYPE_CASE(s) \
|
||||
case s: \
|
||||
return #s
|
||||
|
||||
TYPE_CASE(NPY_BOOL);
|
||||
TYPE_CASE(NPY_BYTE);
|
||||
TYPE_CASE(NPY_UBYTE);
|
||||
TYPE_CASE(NPY_SHORT);
|
||||
TYPE_CASE(NPY_USHORT);
|
||||
TYPE_CASE(NPY_INT);
|
||||
TYPE_CASE(NPY_UINT);
|
||||
TYPE_CASE(NPY_LONG);
|
||||
TYPE_CASE(NPY_ULONG);
|
||||
TYPE_CASE(NPY_LONGLONG);
|
||||
TYPE_CASE(NPY_ULONGLONG);
|
||||
TYPE_CASE(NPY_FLOAT);
|
||||
TYPE_CASE(NPY_DOUBLE);
|
||||
TYPE_CASE(NPY_LONGDOUBLE);
|
||||
TYPE_CASE(NPY_CFLOAT);
|
||||
TYPE_CASE(NPY_CDOUBLE);
|
||||
TYPE_CASE(NPY_CLONGDOUBLE);
|
||||
TYPE_CASE(NPY_OBJECT);
|
||||
TYPE_CASE(NPY_STRING);
|
||||
TYPE_CASE(NPY_UNICODE);
|
||||
TYPE_CASE(NPY_VOID);
|
||||
TYPE_CASE(NPY_DATETIME);
|
||||
TYPE_CASE(NPY_TIMEDELTA);
|
||||
TYPE_CASE(NPY_HALF);
|
||||
TYPE_CASE(NPY_NTYPES);
|
||||
TYPE_CASE(NPY_NOTYPE);
|
||||
TYPE_CASE(NPY_CHAR);
|
||||
TYPE_CASE(NPY_USERDEF);
|
||||
default:
|
||||
return "not a numpy type";
|
||||
}
|
||||
}
|
||||
|
||||
Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
|
||||
TF_DataType* out_tf_datatype) {
|
||||
PyObject* key;
|
||||
PyObject* value;
|
||||
Py_ssize_t pos = 0;
|
||||
if (PyDict_Next(descr->fields, &pos, &key, &value)) {
|
||||
// In Python 3, the keys of numpy custom struct types are unicode, unlike
|
||||
// Python 2, where the keys are bytes.
|
||||
const char* key_string =
|
||||
PyBytes_Check(key) ? PyBytes_AsString(key)
|
||||
: PyBytes_AsString(PyUnicode_AsASCIIString(key));
|
||||
if (!key_string) {
|
||||
return errors::Internal("Corrupt numpy type descriptor");
|
||||
}
|
||||
tensorflow::string key = key_string;
|
||||
// The typenames here should match the field names in the custom struct
|
||||
// types constructed in test_util.py.
|
||||
// TODO(mrry,keveman): Investigate Numpy type registration to replace this
|
||||
// hard-coding of names.
|
||||
if (key == "quint8") {
|
||||
*out_tf_datatype = TF_QUINT8;
|
||||
} else if (key == "qint8") {
|
||||
*out_tf_datatype = TF_QINT8;
|
||||
} else if (key == "qint16") {
|
||||
*out_tf_datatype = TF_QINT16;
|
||||
} else if (key == "quint16") {
|
||||
*out_tf_datatype = TF_QUINT16;
|
||||
} else if (key == "qint32") {
|
||||
*out_tf_datatype = TF_QINT32;
|
||||
} else if (key == "resource") {
|
||||
*out_tf_datatype = TF_RESOURCE;
|
||||
} else {
|
||||
return errors::Internal("Unsupported numpy data type");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::Internal("Unsupported numpy data type");
|
||||
}
|
||||
|
||||
Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
|
||||
TF_DataType* out_tf_datatype) {
|
||||
int pyarray_type = PyArray_TYPE(array);
|
||||
PyArray_Descr* descr = PyArray_DESCR(array);
|
||||
switch (pyarray_type) {
|
||||
case NPY_FLOAT16:
|
||||
*out_tf_datatype = TF_HALF;
|
||||
break;
|
||||
case NPY_FLOAT32:
|
||||
*out_tf_datatype = TF_FLOAT;
|
||||
break;
|
||||
case NPY_FLOAT64:
|
||||
*out_tf_datatype = TF_DOUBLE;
|
||||
break;
|
||||
case NPY_INT32:
|
||||
*out_tf_datatype = TF_INT32;
|
||||
break;
|
||||
case NPY_UINT8:
|
||||
*out_tf_datatype = TF_UINT8;
|
||||
break;
|
||||
case NPY_UINT16:
|
||||
*out_tf_datatype = TF_UINT16;
|
||||
break;
|
||||
case NPY_UINT32:
|
||||
*out_tf_datatype = TF_UINT32;
|
||||
break;
|
||||
case NPY_UINT64:
|
||||
*out_tf_datatype = TF_UINT64;
|
||||
break;
|
||||
case NPY_INT8:
|
||||
*out_tf_datatype = TF_INT8;
|
||||
break;
|
||||
case NPY_INT16:
|
||||
*out_tf_datatype = TF_INT16;
|
||||
break;
|
||||
case NPY_INT64:
|
||||
*out_tf_datatype = TF_INT64;
|
||||
break;
|
||||
case NPY_BOOL:
|
||||
*out_tf_datatype = TF_BOOL;
|
||||
break;
|
||||
case NPY_COMPLEX64:
|
||||
*out_tf_datatype = TF_COMPLEX64;
|
||||
break;
|
||||
case NPY_COMPLEX128:
|
||||
*out_tf_datatype = TF_COMPLEX128;
|
||||
break;
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE:
|
||||
*out_tf_datatype = TF_STRING;
|
||||
break;
|
||||
case NPY_VOID:
|
||||
// Quantized types are currently represented as custom struct types.
|
||||
// PyArray_TYPE returns NPY_VOID for structs, and we should look into
|
||||
// descr to derive the actual type.
|
||||
// Direct feeds of certain types of ResourceHandles are represented as a
|
||||
// custom struct type.
|
||||
return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
|
||||
default:
|
||||
if (pyarray_type == Bfloat16NumpyType()) {
|
||||
*out_tf_datatype = TF_BFLOAT16;
|
||||
break;
|
||||
} else if (pyarray_type == NPY_ULONGLONG) {
|
||||
// NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
|
||||
// might be different on certain platforms.
|
||||
*out_tf_datatype = TF_UINT64;
|
||||
break;
|
||||
} else if (pyarray_type == NPY_LONGLONG) {
|
||||
// NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
|
||||
// might be different on certain platforms.
|
||||
*out_tf_datatype = TF_INT64;
|
||||
break;
|
||||
}
|
||||
return errors::Internal("Unsupported numpy type: ",
|
||||
numpy_type_name(pyarray_type));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
|
||||
PyObject** ptr_owner) {
|
||||
*ptr_owner = nullptr;
|
||||
@ -344,38 +186,6 @@ Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Determine the type description (PyArray_Descr) of a numpy ndarray to be
|
||||
// created to represent an output Tensor.
|
||||
Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
|
||||
PyArray_Descr** descr) {
|
||||
if (TF_TensorType(tensor) == TF_RESOURCE) {
|
||||
PyObject* field = PyTuple_New(3);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
|
||||
#else
|
||||
PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
|
||||
#endif
|
||||
PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
|
||||
PyTuple_SetItem(field, 2, PyLong_FromLong(1));
|
||||
PyObject* fields = PyList_New(1);
|
||||
PyList_SetItem(fields, 0, field);
|
||||
int convert_result = PyArray_DescrConverter(fields, descr);
|
||||
Py_CLEAR(field);
|
||||
Py_CLEAR(fields);
|
||||
if (convert_result != 1) {
|
||||
return errors::Internal("Failed to create numpy array description for ",
|
||||
"TF_RESOURCE-type tensor");
|
||||
}
|
||||
} else {
|
||||
int type_num = -1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
|
||||
*descr = PyArray_DescrFromType(type_num);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline void FastMemcpy(void* dst, const void* src, size_t size) {
|
||||
// clang-format off
|
||||
switch (size) {
|
||||
@ -461,7 +271,8 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
|
||||
// Copy the TF_TensorData into a newly-created ndarray and return it.
|
||||
PyArray_Descr* descr = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
|
||||
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(
|
||||
static_cast<DataType>(TF_TensorType(tensor.get())), &descr));
|
||||
Safe_PyObjectPtr safe_out_array =
|
||||
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
|
||||
if (!safe_out_array) {
|
||||
@ -499,7 +310,11 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
||||
|
||||
// Convert numpy dtype to TensorFlow dtype.
|
||||
TF_DataType dtype = TF_FLOAT;
|
||||
TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
|
||||
{
|
||||
DataType tmp;
|
||||
TF_RETURN_IF_ERROR(PyArray_DescrToDataType(PyArray_DESCR(array), &tmp));
|
||||
dtype = static_cast<TF_DataType>(tmp);
|
||||
}
|
||||
|
||||
tensorflow::int64 nelems = 1;
|
||||
gtl::InlinedVector<int64_t, 4> dims;
|
||||
|
@ -13,16 +13,19 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Must be included first.
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// Must be included first.
|
||||
// clang-format: off
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
// clang-format: on
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -107,85 +110,6 @@ PyTypeObject TensorReleaserType = {
|
||||
nullptr, /* tp_richcompare */
|
||||
};
|
||||
|
||||
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
|
||||
int* out_pyarray_type) {
|
||||
switch (tf_datatype) {
|
||||
case TF_HALF:
|
||||
*out_pyarray_type = NPY_FLOAT16;
|
||||
break;
|
||||
case TF_FLOAT:
|
||||
*out_pyarray_type = NPY_FLOAT32;
|
||||
break;
|
||||
case TF_DOUBLE:
|
||||
*out_pyarray_type = NPY_FLOAT64;
|
||||
break;
|
||||
case TF_INT32:
|
||||
*out_pyarray_type = NPY_INT32;
|
||||
break;
|
||||
case TF_UINT32:
|
||||
*out_pyarray_type = NPY_UINT32;
|
||||
break;
|
||||
case TF_UINT8:
|
||||
*out_pyarray_type = NPY_UINT8;
|
||||
break;
|
||||
case TF_UINT16:
|
||||
*out_pyarray_type = NPY_UINT16;
|
||||
break;
|
||||
case TF_INT8:
|
||||
*out_pyarray_type = NPY_INT8;
|
||||
break;
|
||||
case TF_INT16:
|
||||
*out_pyarray_type = NPY_INT16;
|
||||
break;
|
||||
case TF_INT64:
|
||||
*out_pyarray_type = NPY_INT64;
|
||||
break;
|
||||
case TF_UINT64:
|
||||
*out_pyarray_type = NPY_UINT64;
|
||||
break;
|
||||
case TF_BOOL:
|
||||
*out_pyarray_type = NPY_BOOL;
|
||||
break;
|
||||
case TF_COMPLEX64:
|
||||
*out_pyarray_type = NPY_COMPLEX64;
|
||||
break;
|
||||
case TF_COMPLEX128:
|
||||
*out_pyarray_type = NPY_COMPLEX128;
|
||||
break;
|
||||
case TF_STRING:
|
||||
*out_pyarray_type = NPY_OBJECT;
|
||||
break;
|
||||
case TF_RESOURCE:
|
||||
*out_pyarray_type = NPY_VOID;
|
||||
break;
|
||||
// TODO(keveman): These should be changed to NPY_VOID, and the type used for
|
||||
// the resulting numpy array should be the custom struct types that we
|
||||
// expect for quantized types.
|
||||
case TF_QINT8:
|
||||
*out_pyarray_type = NPY_INT8;
|
||||
break;
|
||||
case TF_QUINT8:
|
||||
*out_pyarray_type = NPY_UINT8;
|
||||
break;
|
||||
case TF_QINT16:
|
||||
*out_pyarray_type = NPY_INT16;
|
||||
break;
|
||||
case TF_QUINT16:
|
||||
*out_pyarray_type = NPY_UINT16;
|
||||
break;
|
||||
case TF_QINT32:
|
||||
*out_pyarray_type = NPY_INT32;
|
||||
break;
|
||||
case TF_BFLOAT16:
|
||||
*out_pyarray_type = Bfloat16NumpyType();
|
||||
break;
|
||||
default:
|
||||
return errors::Internal("Tensorflow type ", tf_datatype,
|
||||
" not convertible to numpy dtype.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
||||
std::function<void()> destructor, PyObject** result) {
|
||||
if (dtype == DT_STRING || dtype == DT_RESOURCE) {
|
||||
@ -193,15 +117,11 @@ Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
||||
"Cannot convert string or resource Tensors.");
|
||||
}
|
||||
|
||||
int type_num = -1;
|
||||
Status s =
|
||||
TF_DataType_to_PyArray_TYPE(static_cast<TF_DataType>(dtype), &type_num);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
|
||||
PyArray_Descr* descr = nullptr;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(dtype, &descr));
|
||||
auto* np_array = reinterpret_cast<PyArrayObject*>(
|
||||
PyArray_SimpleNewFromData(dim_size, dims, type_num, data));
|
||||
PyArray_SimpleNewFromData(dim_size, dims, descr->type_num, data));
|
||||
CHECK_NE(np_array, nullptr);
|
||||
PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);
|
||||
if (PyType_Ready(&TensorReleaserType) == -1) {
|
||||
return errors::Unknown("Python type initialization failed.");
|
||||
|
@ -42,10 +42,6 @@ void ClearDecrefCache();
|
||||
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
|
||||
std::function<void()> destructor, PyObject** result);
|
||||
|
||||
// Converts TF_DataType to the corresponding numpy type.
|
||||
Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
|
||||
int* out_pyarray_type);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_
|
||||
|
287
tensorflow/python/lib/core/ndarray_tensor_types.cc
Normal file
287
tensorflow/python/lib/core/ndarray_tensor_types.cc
Normal file
@ -0,0 +1,287 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
// Must be included first.
|
||||
// clang-format: off
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
// clang-format: on
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
PyArray_Descr* BFLOAT16_DESCR = nullptr;
|
||||
PyArray_Descr* QINT8_DESCR = nullptr;
|
||||
PyArray_Descr* QINT16_DESCR = nullptr;
|
||||
PyArray_Descr* QINT32_DESCR = nullptr;
|
||||
PyArray_Descr* QUINT8_DESCR = nullptr;
|
||||
PyArray_Descr* QUINT16_DESCR = nullptr;
|
||||
PyArray_Descr* RESOURCE_DESCR = nullptr;
|
||||
|
||||
// Define a struct array data type `[(tag, type_num)]`.
|
||||
PyArray_Descr* DefineStructTypeAlias(const char* tag, int type_num) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
auto* py_tag = PyBytes_FromString(tag);
|
||||
#else
|
||||
auto* py_tag = PyUnicode_FromString(tag);
|
||||
#endif
|
||||
auto* descr = PyArray_DescrFromType(type_num);
|
||||
auto* py_tag_and_descr = PyTuple_Pack(2, py_tag, descr);
|
||||
auto* obj = PyList_New(1);
|
||||
PyList_SetItem(obj, 0, py_tag_and_descr);
|
||||
PyArray_Descr* alias_descr = nullptr;
|
||||
// TODO(slebedev): Switch to PyArray_DescrNewFromType because struct
|
||||
// array dtypes could not be used for scalars. Note that this will
|
||||
// require registering type conversions and UFunc specializations.
|
||||
// See b/144230631.
|
||||
CHECK_EQ(PyArray_DescrConverter(obj, &alias_descr), NPY_SUCCEED);
|
||||
Py_DECREF(obj);
|
||||
Py_DECREF(py_tag_and_descr);
|
||||
Py_DECREF(py_tag);
|
||||
Py_DECREF(descr);
|
||||
CHECK_NE(alias_descr, nullptr);
|
||||
return alias_descr;
|
||||
}
|
||||
|
||||
void MaybeRegisterCustomNumPyTypes() {
|
||||
static bool registered = false;
|
||||
if (registered) return;
|
||||
ImportNumpy(); // Ensure NumPy is loaded.
|
||||
// Make sure the tags are consistent with DataTypeToPyArray_Descr.
|
||||
QINT8_DESCR = DefineStructTypeAlias("qint8", NPY_INT8);
|
||||
QINT16_DESCR = DefineStructTypeAlias("qint16", NPY_INT16);
|
||||
QINT32_DESCR = DefineStructTypeAlias("qint32", NPY_INT32);
|
||||
QUINT8_DESCR = DefineStructTypeAlias("quint8", NPY_UINT8);
|
||||
QUINT16_DESCR = DefineStructTypeAlias("quint16", NPY_UINT16);
|
||||
RESOURCE_DESCR = DefineStructTypeAlias("resource", NPY_UBYTE);
|
||||
RegisterNumpyBfloat16();
|
||||
BFLOAT16_DESCR = PyArray_DescrFromType(Bfloat16NumpyType());
|
||||
registered = true;
|
||||
}
|
||||
|
||||
const char* PyArray_DescrReprAsString(PyArray_Descr* descr) {
|
||||
auto* descr_repr = PyObject_Repr(reinterpret_cast<PyObject*>(descr));
|
||||
const char* result;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
result = PyBytes_AsString(descr_repr);
|
||||
#else
|
||||
auto* tmp = PyUnicode_AsASCIIString(descr_repr);
|
||||
result = PyBytes_AsString(tmp);
|
||||
Py_DECREF(tmp);
|
||||
#endif
|
||||
|
||||
Py_DECREF(descr_repr);
|
||||
return result;
|
||||
}
|
||||
|
||||
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr) {
|
||||
switch (dt) {
|
||||
case DT_HALF:
|
||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT16);
|
||||
break;
|
||||
case DT_FLOAT:
|
||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT32);
|
||||
break;
|
||||
case DT_DOUBLE:
|
||||
*out_descr = PyArray_DescrFromType(NPY_FLOAT64);
|
||||
break;
|
||||
case DT_INT32:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT32);
|
||||
break;
|
||||
case DT_UINT32:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT32);
|
||||
break;
|
||||
case DT_UINT8:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT8);
|
||||
break;
|
||||
case DT_UINT16:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT16);
|
||||
break;
|
||||
case DT_INT8:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT8);
|
||||
break;
|
||||
case DT_INT16:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT16);
|
||||
break;
|
||||
case DT_INT64:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT64);
|
||||
break;
|
||||
case DT_UINT64:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT64);
|
||||
break;
|
||||
case DT_BOOL:
|
||||
*out_descr = PyArray_DescrFromType(NPY_BOOL);
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
*out_descr = PyArray_DescrFromType(NPY_COMPLEX64);
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
*out_descr = PyArray_DescrFromType(NPY_COMPLEX128);
|
||||
break;
|
||||
case DT_STRING:
|
||||
*out_descr = PyArray_DescrFromType(NPY_OBJECT);
|
||||
break;
|
||||
case DT_QINT8:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT8);
|
||||
break;
|
||||
case DT_QINT16:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT16);
|
||||
break;
|
||||
case DT_QINT32:
|
||||
*out_descr = PyArray_DescrFromType(NPY_INT32);
|
||||
break;
|
||||
case DT_QUINT8:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT8);
|
||||
break;
|
||||
case DT_QUINT16:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UINT16);
|
||||
break;
|
||||
case DT_RESOURCE:
|
||||
*out_descr = PyArray_DescrFromType(NPY_UBYTE);
|
||||
break;
|
||||
case DT_BFLOAT16:
|
||||
Py_INCREF(BFLOAT16_DESCR);
|
||||
*out_descr = BFLOAT16_DESCR;
|
||||
break;
|
||||
default:
|
||||
return errors::Internal("TensorFlow data type ", DataType_Name(dt),
|
||||
" cannot be converted to a NumPy data type.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// NumPy defines fixed-width aliases for platform integer types. However,
|
||||
// some types do not have a fixed-width alias. Specifically
|
||||
//
|
||||
// * on a LLP64 system NPY_INT32 == NPY_LONG therefore NPY_INT is not aliased;
|
||||
// * on a LP64 system NPY_INT64 == NPY_LONG and NPY_LONGLONG is not aliased.
|
||||
//
|
||||
int MaybeResolveNumPyPlatformType(int type_num) {
|
||||
switch (type_num) {
|
||||
#if NPY_BITS_OF_INT == 32 && NPY_BITS_OF_LONGLONG == 32
|
||||
case NPY_INT:
|
||||
return NPY_INT32;
|
||||
case NPY_UINT:
|
||||
return NPY_UINT32;
|
||||
#endif
|
||||
#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONGLONG == 64
|
||||
case NPY_LONGLONG:
|
||||
return NPY_INT64;
|
||||
case NPY_ULONGLONG:
|
||||
return NPY_UINT64;
|
||||
#endif
|
||||
default:
|
||||
return type_num;
|
||||
}
|
||||
}
|
||||
|
||||
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt) {
|
||||
const int type_num = MaybeResolveNumPyPlatformType(descr->type_num);
|
||||
switch (type_num) {
|
||||
case NPY_FLOAT16:
|
||||
*out_dt = DT_HALF;
|
||||
break;
|
||||
case NPY_FLOAT32:
|
||||
*out_dt = DT_FLOAT;
|
||||
break;
|
||||
case NPY_FLOAT64:
|
||||
*out_dt = DT_DOUBLE;
|
||||
break;
|
||||
case NPY_INT8:
|
||||
*out_dt = DT_INT8;
|
||||
break;
|
||||
case NPY_INT16:
|
||||
*out_dt = DT_INT16;
|
||||
break;
|
||||
case NPY_INT32:
|
||||
*out_dt = DT_INT32;
|
||||
break;
|
||||
case NPY_INT64:
|
||||
*out_dt = DT_INT64;
|
||||
break;
|
||||
case NPY_UINT8:
|
||||
*out_dt = DT_UINT8;
|
||||
break;
|
||||
case NPY_UINT16:
|
||||
*out_dt = DT_UINT16;
|
||||
break;
|
||||
case NPY_UINT32:
|
||||
*out_dt = DT_UINT32;
|
||||
break;
|
||||
case NPY_UINT64:
|
||||
*out_dt = DT_UINT64;
|
||||
break;
|
||||
case NPY_BOOL:
|
||||
*out_dt = DT_BOOL;
|
||||
break;
|
||||
case NPY_COMPLEX64:
|
||||
*out_dt = DT_COMPLEX64;
|
||||
break;
|
||||
case NPY_COMPLEX128:
|
||||
*out_dt = DT_COMPLEX128;
|
||||
break;
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE:
|
||||
*out_dt = DT_STRING;
|
||||
break;
|
||||
case NPY_VOID: {
|
||||
if (descr == QINT8_DESCR) {
|
||||
*out_dt = DT_QINT8;
|
||||
break;
|
||||
} else if (descr == QINT16_DESCR) {
|
||||
*out_dt = DT_QINT16;
|
||||
break;
|
||||
} else if (descr == QINT32_DESCR) {
|
||||
*out_dt = DT_QINT32;
|
||||
break;
|
||||
} else if (descr == QUINT8_DESCR) {
|
||||
*out_dt = DT_QUINT8;
|
||||
break;
|
||||
} else if (descr == QUINT16_DESCR) {
|
||||
*out_dt = DT_QUINT16;
|
||||
break;
|
||||
} else if (descr == RESOURCE_DESCR) {
|
||||
*out_dt = DT_RESOURCE;
|
||||
break;
|
||||
}
|
||||
|
||||
return errors::Internal("Unsupported NumPy struct data type: ",
|
||||
PyArray_DescrReprAsString(descr));
|
||||
}
|
||||
default:
|
||||
if (type_num == Bfloat16NumpyType()) {
|
||||
*out_dt = DT_BFLOAT16;
|
||||
break;
|
||||
}
|
||||
|
||||
return errors::Internal("Unregistered NumPy data type: ",
|
||||
PyArray_DescrReprAsString(descr));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
65
tensorflow/python/lib/core/ndarray_tensor_types.h
Normal file
65
tensorflow/python/lib/core/ndarray_tensor_types.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
||||
#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
||||
|
||||
// Must be included first.
|
||||
// clang-format: off
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
// clang-format: on
|
||||
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
extern PyArray_Descr* QINT8_DESCR;
|
||||
extern PyArray_Descr* QINT16_DESCR;
|
||||
extern PyArray_Descr* QINT32_DESCR;
|
||||
extern PyArray_Descr* QUINT8_DESCR;
|
||||
extern PyArray_Descr* QUINT16_DESCR;
|
||||
extern PyArray_Descr* RESOURCE_DESCR;
|
||||
extern PyArray_Descr* BFLOAT16_DESCR;
|
||||
|
||||
// Register custom NumPy types.
|
||||
//
|
||||
// This function must be called in order to be able to map TensorFlow
|
||||
// data types which do not have a corresponding standard NumPy data type
|
||||
// (e.g. bfloat16 or qint8).
|
||||
//
|
||||
// TODO(b/144230631): The name is slightly misleading, as the function only
|
||||
// registers bfloat16 and defines structured aliases for other data types
|
||||
// (e.g. qint8).
|
||||
void MaybeRegisterCustomNumPyTypes();
|
||||
|
||||
// Returns a NumPy data type matching a given tensorflow::DataType. If the
|
||||
// function call succeeds, the caller is responsible for DECREF'ing the
|
||||
// resulting PyArray_Descr*.
|
||||
//
|
||||
// NumPy does not support quantized integer types, so TensorFlow defines
|
||||
// structured aliases for them, e.g. tf.qint8 is represented as
|
||||
// np.dtype([("qint8", np.int8)]). However, for historical reasons this
|
||||
// function does not use these aliases, and instead returns the *aliased*
|
||||
// types (np.int8 in the example).
|
||||
// TODO(b/144230631): Return an alias instead of the aliased type.
|
||||
Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr);
|
||||
|
||||
// Returns a tensorflow::DataType corresponding to a given NumPy data type.
|
||||
Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
|
@ -175,7 +175,7 @@ def get_pybind_export_symbols(symbols_file, lib_paths_file):
|
||||
else:
|
||||
# If not a section header and not an empty line, then it's a symbol
|
||||
# line. e.g. `tensorflow::swig::IsSequence`
|
||||
symbols[curr_lib].append(line)
|
||||
symbols[curr_lib].append(re.escape(line))
|
||||
|
||||
lib_paths = []
|
||||
with open(lib_paths_file, "r") as f:
|
||||
|
@ -43,10 +43,6 @@ tensorflow::tfprof::SerializeToString
|
||||
[graph_analyzer_tool] # graph_analyzer
|
||||
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
|
||||
|
||||
[bfloat16_lib] # bfloat16
|
||||
tensorflow::RegisterNumpyBfloat16
|
||||
tensorflow::Bfloat16PyType
|
||||
|
||||
[events_writer] # events_writer
|
||||
tensorflow::EventsWriter::Init
|
||||
tensorflow::EventsWriter::InitWithSuffix
|
||||
@ -189,3 +185,13 @@ tensorflow::Set_TF_Status_from_Status
|
||||
|
||||
[context] # tfe
|
||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
||||
[ndarray_tensor_types] # _dtypes
|
||||
tensorflow::MaybeRegisterCustomNumPyTypes
|
||||
tensorflow::BFLOAT16_DESCR
|
||||
tensorflow::QINT8_DESCR
|
||||
tensorflow::QINT16_DESCR
|
||||
tensorflow::QINT32_DESCR
|
||||
tensorflow::QUINT8_DESCR
|
||||
tensorflow::QUINT16_DESCR
|
||||
tensorflow::RESOURCE_DESCR
|
||||
|
Loading…
Reference in New Issue
Block a user