Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well.

PiperOrigin-RevId: 315013002
Change-Id: Ia99673bfc091bb29d2ea820b0ff253eb63d80689
This commit is contained in:
Gunhan Gulsoy 2020-06-05 16:02:28 -07:00 committed by TensorFlower Gardener
parent f5f25ec023
commit f6db19cc34
3 changed files with 20 additions and 55 deletions

View File

@ -29,25 +29,15 @@ limitations under the License.
namespace tensorflow {
namespace swig {
namespace {
string PyObjectToString(PyObject* o);
} // namespace
std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
std::unordered_map<string, PyObject*>* PythonTypesMap() {
static auto* m = new std::unordered_map<string, PyObject*>();
return m;
}
PyObject* GetRegisteredPyObject(const string& name) {
const auto* m = RegisteredPyObjectMap();
auto it = m->find(name);
if (it == m->end()) {
PyErr_SetString(PyExc_TypeError,
tensorflow::strings::StrCat("No object with name ", name,
" has been registered.")
.c_str());
return nullptr;
}
PyObject* GetRegisteredType(const string& key) {
auto* m = PythonTypesMap();
auto it = m->find(key);
if (it == m->end()) return nullptr;
return it->second;
}
@ -59,35 +49,26 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
.c_str());
return nullptr;
}
return RegisterPyObject(type_name, type);
}
PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
string key;
if (PyBytes_Check(name)) {
key = PyBytes_AsString(name);
if (PyBytes_Check(type_name)) {
key = PyBytes_AsString(type_name);
}
#if PY_MAJOR_VERSION >= 3
} else if (PyUnicode_Check(name)) {
key = PyUnicode_AsUTF8(name);
if (PyUnicode_Check(type_name)) {
key = PyUnicode_AsUTF8(type_name);
}
#endif
} else {
if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
"Expected name to be a str, got",
PyObjectToString(name))
"Type already registered for ", key)
.c_str());
return nullptr;
}
auto* m = RegisteredPyObjectMap();
if (m->find(key) != m->end()) {
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
"Value already registered for ", key)
.c_str());
return nullptr;
}
Py_INCREF(value);
m->emplace(key, value);
Py_INCREF(type);
PythonTypesMap()->emplace(key, type);
Py_RETURN_NONE;
}
@ -215,7 +196,7 @@ class CachedTypeCheck {
// Returns 0 otherwise.
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
PyObject* type_obj = GetRegisteredPyObject(type_name);
PyObject* type_obj = GetRegisteredType(type_name);
if (TF_PREDICT_FALSE(type_obj == nullptr)) {
PyErr_SetString(PyExc_RuntimeError,
tensorflow::strings::StrCat(
@ -532,8 +513,7 @@ class AttrsValueIterator : public ValueIterator {
};
bool IsSparseTensorValueType(PyObject* o) {
PyObject* sparse_tensor_value_type =
GetRegisteredPyObject("SparseTensorValue");
PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}

View File

@ -19,8 +19,6 @@ limitations under the License.
#include <Python.h>
#include <string>
namespace tensorflow {
namespace swig {
@ -272,20 +270,11 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
// Registers a Python object so it can be looked up from c++. The set of
// valid names, and the expected values for those names, are listed in
// the documentation for `RegisteredPyObjects`. Returns PyNone.
PyObject* RegisterPyObject(PyObject* name, PyObject* value);
// Variant of RegisterPyObject that requires the object's value to be a type.
// RegisterType is used to pass PyTypeObject (which is defined in python) for an
// arbitrary identifier `type_name` into C++.
PyObject* RegisterType(PyObject* type_name, PyObject* type);
} // namespace swig
// Returns a borrowed reference to an object that was registered with
// RegisterPyObject. (Do not call PY_DECREF on the result).
PyObject* GetRegisteredPyObject(const std::string& name);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_

View File

@ -30,10 +30,6 @@ PYBIND11_MODULE(_pywrap_utils, m) {
return tensorflow::PyoOrThrow(
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
});
m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) {
return tensorflow::PyoOrThrow(
tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr()));
});
m.def(
"IsTensor",
[](const py::handle& o) {