From 94a73f439383defabce00ea3cffa20920c2a7dc6 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 5 Jun 2020 14:24:34 -0700 Subject: [PATCH] Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well. PiperOrigin-RevId: 314996128 Change-Id: I8edca2552d5d45cf74a1f5fb00bc88996af033da --- tensorflow/python/util/util.cc | 56 +++++++++++++++++--------- tensorflow/python/util/util.h | 15 ++++++- tensorflow/python/util/util_wrapper.cc | 4 ++ 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 1d0dd695d74..cf8581443e7 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -29,15 +29,25 @@ limitations under the License. namespace tensorflow { namespace swig { -std::unordered_map* PythonTypesMap() { +namespace { +string PyObjectToString(PyObject* o); +} // namespace + +std::unordered_map* RegisteredPyObjectMap() { static auto* m = new std::unordered_map(); return m; } -PyObject* GetRegisteredType(const string& key) { - auto* m = PythonTypesMap(); - auto it = m->find(key); - if (it == m->end()) return nullptr; +PyObject* GetRegisteredPyObject(const string& name) { + const auto* m = RegisteredPyObjectMap(); + auto it = m->find(name); + if (it == m->end()) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat("No object with name ", name, + " has been registered.") + .c_str()); + return nullptr; + } return it->second; } @@ -49,26 +59,35 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) { .c_str()); return nullptr; } + return RegisterPyObject(type_name, type); +} +PyObject* RegisterPyObject(PyObject* name, PyObject* value) { string key; - if (PyBytes_Check(type_name)) { - key = PyBytes_AsString(type_name); - } + if (PyBytes_Check(name)) { + key = PyBytes_AsString(name); #if PY_MAJOR_VERSION >= 3 - if (PyUnicode_Check(type_name)) { - key = PyUnicode_AsUTF8(type_name); - } + } else if (PyUnicode_Check(name)) { + key = PyUnicode_AsUTF8(name); #endif - - if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) { + } else { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( - "Type already registered for ", key) + "Expected name to be a str, got", + PyObjectToString(name)) .c_str()); return nullptr; } - Py_INCREF(type); - PythonTypesMap()->emplace(key, type); + auto* m = RegisteredPyObjectMap(); + if (m->find(key) != m->end()) { + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( + "Value already registered for ", key) + .c_str()); + return nullptr; + } + + Py_INCREF(value); + m->emplace(key, value); Py_RETURN_NONE; } @@ -196,7 +215,7 @@ class CachedTypeCheck { // Returns 0 otherwise. // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.) int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) { - PyObject* type_obj = GetRegisteredType(type_name); + PyObject* type_obj = GetRegisteredPyObject(type_name); if (TF_PREDICT_FALSE(type_obj == nullptr)) { PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat( @@ -513,7 +532,8 @@ class AttrsValueIterator : public ValueIterator { }; bool IsSparseTensorValueType(PyObject* o) { - PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue"); + PyObject* sparse_tensor_value_type = + GetRegisteredPyObject("SparseTensorValue"); if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { return false; } diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 23438b43c53..fc0b864416e 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -19,6 +19,8 @@ limitations under the License. #include +#include + namespace tensorflow { namespace swig { @@ -270,11 +272,20 @@ PyObject* FlattenForData(PyObject* nested); PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, bool check_types); -// RegisterType is used to pass PyTypeObject (which is defined in python) for an -// arbitrary identifier `type_name` into C++. +// Registers a Python object so it can be looked up from c++. The set of +// valid names, and the expected values for those names, are listed in +// the documentation for `RegisteredPyObjects`. Returns PyNone. +PyObject* RegisterPyObject(PyObject* name, PyObject* value); + +// Variant of RegisterPyObject that requires the object's value to be a type. PyObject* RegisterType(PyObject* type_name, PyObject* type); } // namespace swig + +// Returns a borrowed reference to an object that was registered with +// RegisterPyObject. (Do not call PY_DECREF on the result). +PyObject* GetRegisteredPyObject(const std::string& name); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_ diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc index dd74306413c..63c70d785cc 100644 --- a/tensorflow/python/util/util_wrapper.cc +++ b/tensorflow/python/util/util_wrapper.cc @@ -30,6 +30,10 @@ PYBIND11_MODULE(_pywrap_utils, m) { return tensorflow::PyoOrThrow( tensorflow::swig::RegisterType(type_name.ptr(), type.ptr())); }); + m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) { + return tensorflow::PyoOrThrow( + tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr())); + }); m.def( "IsTensor", [](const py::handle& o) {