Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well.

PiperOrigin-RevId: 316904361
Change-Id: I4ec98c861742efddcebd140ff9e1a6ff567cc94c
This commit is contained in:
Edward Loper 2020-06-17 09:33:05 -07:00 committed by TensorFlower Gardener
parent 51ccd6911b
commit b780ee931b
4 changed files with 56 additions and 20 deletions
tensorflow
python/util
tools/def_file_filter

View File

@ -29,15 +29,25 @@ limitations under the License.
namespace tensorflow {
namespace swig {
std::unordered_map<string, PyObject*>* PythonTypesMap() {
namespace {
string PyObjectToString(PyObject* o);
} // namespace
std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
static auto* m = new std::unordered_map<string, PyObject*>();
return m;
}
PyObject* GetRegisteredType(const string& key) {
auto* m = PythonTypesMap();
auto it = m->find(key);
if (it == m->end()) return nullptr;
PyObject* GetRegisteredPyObject(const string& name) {
const auto* m = RegisteredPyObjectMap();
auto it = m->find(name);
if (it == m->end()) {
PyErr_SetString(PyExc_TypeError,
tensorflow::strings::StrCat("No object with name ", name,
" has been registered.")
.c_str());
return nullptr;
}
return it->second;
}
@ -49,26 +59,35 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
.c_str());
return nullptr;
}
return RegisterPyObject(type_name, type);
}
PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
string key;
if (PyBytes_Check(type_name)) {
key = PyBytes_AsString(type_name);
}
if (PyBytes_Check(name)) {
key = PyBytes_AsString(name);
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(type_name)) {
key = PyUnicode_AsUTF8(type_name);
}
} else if (PyUnicode_Check(name)) {
key = PyUnicode_AsUTF8(name);
#endif
if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
} else {
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
"Type already registered for ", key)
"Expected name to be a str, got",
PyObjectToString(name))
.c_str());
return nullptr;
}
Py_INCREF(type);
PythonTypesMap()->emplace(key, type);
auto* m = RegisteredPyObjectMap();
if (m->find(key) != m->end()) {
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
"Value already registered for ", key)
.c_str());
return nullptr;
}
Py_INCREF(value);
m->emplace(key, value);
Py_RETURN_NONE;
}
@ -196,7 +215,7 @@ class CachedTypeCheck {
// Returns 0 otherwise.
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
PyObject* type_obj = GetRegisteredType(type_name);
PyObject* type_obj = GetRegisteredPyObject(type_name);
if (TF_PREDICT_FALSE(type_obj == nullptr)) {
PyErr_SetString(PyExc_RuntimeError,
tensorflow::strings::StrCat(
@ -513,7 +532,8 @@ class AttrsValueIterator : public ValueIterator {
};
bool IsSparseTensorValueType(PyObject* o) {
PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
PyObject* sparse_tensor_value_type =
GetRegisteredPyObject("SparseTensorValue");
if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <Python.h>
#include <string>
namespace tensorflow {
namespace swig {
@ -270,11 +272,20 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
// RegisterType is used to pass PyTypeObject (which is defined in python) for an
// arbitrary identifier `type_name` into C++.
// Registers a Python object so it can be looked up from c++. The set of
// valid names, and the expected values for those names, are listed in
// the documentation for `RegisteredPyObjects`. Returns PyNone.
PyObject* RegisterPyObject(PyObject* name, PyObject* value);
// Variant of RegisterPyObject that requires the object's value to be a type.
PyObject* RegisterType(PyObject* type_name, PyObject* type);
} // namespace swig
// Returns a borrowed reference to an object that was registered with
// RegisterPyObject. (Do not call PY_DECREF on the result).
PyObject* GetRegisteredPyObject(const std::string& name);
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_

View File

@ -30,6 +30,10 @@ PYBIND11_MODULE(_pywrap_utils, m) {
return tensorflow::PyoOrThrow(
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
});
m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) {
return tensorflow::PyoOrThrow(
tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr()));
});
m.def(
"IsTensor",
[](const py::handle& o) {

View File

@ -17,6 +17,7 @@ tensorflow::swig::Flatten
tensorflow::swig::IsSequenceForData
tensorflow::swig::FlattenForData
tensorflow::swig::AssertSameStructureForData
tensorflow::swig::RegisterPyObject
tensorflow::swig::RegisterType
tensorflow::swig::IsEagerTensorSlow