Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well.
PiperOrigin-RevId: 315013002 Change-Id: Ia99673bfc091bb29d2ea820b0ff253eb63d80689
This commit is contained in:
parent
f5f25ec023
commit
f6db19cc34
@ -29,25 +29,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
namespace {
|
||||
string PyObjectToString(PyObject* o);
|
||||
} // namespace
|
||||
|
||||
std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
|
||||
std::unordered_map<string, PyObject*>* PythonTypesMap() {
|
||||
static auto* m = new std::unordered_map<string, PyObject*>();
|
||||
return m;
|
||||
}
|
||||
|
||||
PyObject* GetRegisteredPyObject(const string& name) {
|
||||
const auto* m = RegisteredPyObjectMap();
|
||||
auto it = m->find(name);
|
||||
if (it == m->end()) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat("No object with name ", name,
|
||||
" has been registered.")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
PyObject* GetRegisteredType(const string& key) {
|
||||
auto* m = PythonTypesMap();
|
||||
auto it = m->find(key);
|
||||
if (it == m->end()) return nullptr;
|
||||
return it->second;
|
||||
}
|
||||
|
||||
@ -59,35 +49,26 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return RegisterPyObject(type_name, type);
|
||||
}
|
||||
|
||||
PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
|
||||
string key;
|
||||
if (PyBytes_Check(name)) {
|
||||
key = PyBytes_AsString(name);
|
||||
if (PyBytes_Check(type_name)) {
|
||||
key = PyBytes_AsString(type_name);
|
||||
}
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
} else if (PyUnicode_Check(name)) {
|
||||
key = PyUnicode_AsUTF8(name);
|
||||
if (PyUnicode_Check(type_name)) {
|
||||
key = PyUnicode_AsUTF8(type_name);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
|
||||
if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
|
||||
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
|
||||
"Expected name to be a str, got",
|
||||
PyObjectToString(name))
|
||||
"Type already registered for ", key)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* m = RegisteredPyObjectMap();
|
||||
if (m->find(key) != m->end()) {
|
||||
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
|
||||
"Value already registered for ", key)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Py_INCREF(value);
|
||||
m->emplace(key, value);
|
||||
Py_INCREF(type);
|
||||
PythonTypesMap()->emplace(key, type);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
@ -215,7 +196,7 @@ class CachedTypeCheck {
|
||||
// Returns 0 otherwise.
|
||||
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
|
||||
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
|
||||
PyObject* type_obj = GetRegisteredPyObject(type_name);
|
||||
PyObject* type_obj = GetRegisteredType(type_name);
|
||||
if (TF_PREDICT_FALSE(type_obj == nullptr)) {
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
tensorflow::strings::StrCat(
|
||||
@ -532,8 +513,7 @@ class AttrsValueIterator : public ValueIterator {
|
||||
};
|
||||
|
||||
bool IsSparseTensorValueType(PyObject* o) {
|
||||
PyObject* sparse_tensor_value_type =
|
||||
GetRegisteredPyObject("SparseTensorValue");
|
||||
PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
|
||||
if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -19,8 +19,6 @@ limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
@ -272,20 +270,11 @@ PyObject* FlattenForData(PyObject* nested);
|
||||
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
|
||||
bool check_types);
|
||||
|
||||
// Registers a Python object so it can be looked up from c++. The set of
|
||||
// valid names, and the expected values for those names, are listed in
|
||||
// the documentation for `RegisteredPyObjects`. Returns PyNone.
|
||||
PyObject* RegisterPyObject(PyObject* name, PyObject* value);
|
||||
|
||||
// Variant of RegisterPyObject that requires the object's value to be a type.
|
||||
// RegisterType is used to pass PyTypeObject (which is defined in python) for an
|
||||
// arbitrary identifier `type_name` into C++.
|
||||
PyObject* RegisterType(PyObject* type_name, PyObject* type);
|
||||
|
||||
} // namespace swig
|
||||
|
||||
// Returns a borrowed reference to an object that was registered with
|
||||
// RegisterPyObject. (Do not call PY_DECREF on the result).
|
||||
PyObject* GetRegisteredPyObject(const std::string& name);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_
|
||||
|
@ -30,10 +30,6 @@ PYBIND11_MODULE(_pywrap_utils, m) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
|
||||
});
|
||||
m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) {
|
||||
return tensorflow::PyoOrThrow(
|
||||
tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr()));
|
||||
});
|
||||
m.def(
|
||||
"IsTensor",
|
||||
[](const py::handle& o) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user