Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well.
PiperOrigin-RevId: 316904361 Change-Id: I4ec98c861742efddcebd140ff9e1a6ff567cc94c
This commit is contained in:
parent
51ccd6911b
commit
b780ee931b
tensorflow
@ -29,15 +29,25 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
std::unordered_map<string, PyObject*>* PythonTypesMap() {
|
||||
namespace {
|
||||
string PyObjectToString(PyObject* o);
|
||||
} // namespace
|
||||
|
||||
std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
|
||||
static auto* m = new std::unordered_map<string, PyObject*>();
|
||||
return m;
|
||||
}
|
||||
|
||||
PyObject* GetRegisteredType(const string& key) {
|
||||
auto* m = PythonTypesMap();
|
||||
auto it = m->find(key);
|
||||
if (it == m->end()) return nullptr;
|
||||
PyObject* GetRegisteredPyObject(const string& name) {
|
||||
const auto* m = RegisteredPyObjectMap();
|
||||
auto it = m->find(name);
|
||||
if (it == m->end()) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat("No object with name ", name,
|
||||
" has been registered.")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
@ -49,26 +59,35 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return RegisterPyObject(type_name, type);
|
||||
}
|
||||
|
||||
PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
|
||||
string key;
|
||||
if (PyBytes_Check(type_name)) {
|
||||
key = PyBytes_AsString(type_name);
|
||||
}
|
||||
if (PyBytes_Check(name)) {
|
||||
key = PyBytes_AsString(name);
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
if (PyUnicode_Check(type_name)) {
|
||||
key = PyUnicode_AsUTF8(type_name);
|
||||
}
|
||||
} else if (PyUnicode_Check(name)) {
|
||||
key = PyUnicode_AsUTF8(name);
|
||||
#endif
|
||||
|
||||
if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
|
||||
"Type already registered for ", key)
|
||||
"Expected name to be a str, got",
|
||||
PyObjectToString(name))
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Py_INCREF(type);
|
||||
PythonTypesMap()->emplace(key, type);
|
||||
auto* m = RegisteredPyObjectMap();
|
||||
if (m->find(key) != m->end()) {
|
||||
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
|
||||
"Value already registered for ", key)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Py_INCREF(value);
|
||||
m->emplace(key, value);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
@ -196,7 +215,7 @@ class CachedTypeCheck {
|
||||
// Returns 0 otherwise.
|
||||
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
|
||||
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
|
||||
PyObject* type_obj = GetRegisteredType(type_name);
|
||||
PyObject* type_obj = GetRegisteredPyObject(type_name);
|
||||
if (TF_PREDICT_FALSE(type_obj == nullptr)) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
tensorflow::strings::StrCat(
|
||||
@ -513,7 +532,8 @@ class AttrsValueIterator : public ValueIterator {
|
||||
};
|
||||
|
||||
bool IsSparseTensorValueType(PyObject* o) {
|
||||
PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
|
||||
PyObject* sparse_tensor_value_type =
|
||||
GetRegisteredPyObject("SparseTensorValue");
|
||||
if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
@ -270,11 +272,20 @@ PyObject* FlattenForData(PyObject* nested);
|
||||
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
|
||||
bool check_types);
|
||||
|
||||
// RegisterType is used to pass PyTypeObject (which is defined in python) for an
|
||||
// arbitrary identifier `type_name` into C++.
|
||||
// Registers a Python object so it can be looked up from c++. The set of
|
||||
// valid names, and the expected values for those names, are listed in
|
||||
// the documentation for `RegisteredPyObjects`. Returns PyNone.
|
||||
PyObject* RegisterPyObject(PyObject* name, PyObject* value);
|
||||
|
||||
// Variant of RegisterPyObject that requires the object's value to be a type.
|
||||
PyObject* RegisterType(PyObject* type_name, PyObject* type);
|
||||
|
||||
} // namespace swig
|
||||
|
||||
// Returns a borrowed reference to an object that was registered with
|
||||
// RegisterPyObject. (Do not call PY_DECREF on the result).
|
||||
PyObject* GetRegisteredPyObject(const std::string& name);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_
|
||||
|
@ -30,6 +30,10 @@ PYBIND11_MODULE(_pywrap_utils, m) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
|
||||
});
|
||||
m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr()));
|
||||
});
|
||||
m.def(
|
||||
"IsTensor",
|
||||
[](const py::handle& o) {
|
||||
|
@ -17,6 +17,7 @@ tensorflow::swig::Flatten
|
||||
tensorflow::swig::IsSequenceForData
|
||||
tensorflow::swig::FlattenForData
|
||||
tensorflow::swig::AssertSameStructureForData
|
||||
tensorflow::swig::RegisterPyObject
|
||||
tensorflow::swig::RegisterType
|
||||
tensorflow::swig::IsEagerTensorSlow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user