From f6db19cc34dd38d3f802e7c5d720bacb497e8079 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Fri, 5 Jun 2020 16:02:28 -0700 Subject: [PATCH] Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well. PiperOrigin-RevId: 315013002 Change-Id: Ia99673bfc091bb29d2ea820b0ff253eb63d80689 --- tensorflow/python/util/util.cc | 56 +++++++++----------------- tensorflow/python/util/util.h | 15 +------ tensorflow/python/util/util_wrapper.cc | 4 -- 3 files changed, 20 insertions(+), 55 deletions(-) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index cf8581443e7..1d0dd695d74 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -29,25 +29,15 @@ limitations under the License. namespace tensorflow { namespace swig { -namespace { -string PyObjectToString(PyObject* o); -} // namespace - -std::unordered_map* RegisteredPyObjectMap() { +std::unordered_map* PythonTypesMap() { static auto* m = new std::unordered_map(); return m; } -PyObject* GetRegisteredPyObject(const string& name) { - const auto* m = RegisteredPyObjectMap(); - auto it = m->find(name); - if (it == m->end()) { - PyErr_SetString(PyExc_TypeError, - tensorflow::strings::StrCat("No object with name ", name, - " has been registered.") - .c_str()); - return nullptr; - } +PyObject* GetRegisteredType(const string& key) { + auto* m = PythonTypesMap(); + auto it = m->find(key); + if (it == m->end()) return nullptr; return it->second; } @@ -59,35 +49,26 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) { .c_str()); return nullptr; } - return RegisterPyObject(type_name, type); -} -PyObject* RegisterPyObject(PyObject* name, PyObject* value) { string key; - if (PyBytes_Check(name)) { - key = PyBytes_AsString(name); + if (PyBytes_Check(type_name)) { + key = PyBytes_AsString(type_name); + } #if PY_MAJOR_VERSION >= 3 - } else if (PyUnicode_Check(name)) { - key = PyUnicode_AsUTF8(name); + if (PyUnicode_Check(type_name)) { + key = PyUnicode_AsUTF8(type_name); + } #endif - } else { + + if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( - "Expected name to be a str, got", - PyObjectToString(name)) + "Type already registered for ", key) .c_str()); return nullptr; } - auto* m = RegisteredPyObjectMap(); - if (m->find(key) != m->end()) { - PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( - "Value already registered for ", key) - .c_str()); - return nullptr; - } - - Py_INCREF(value); - m->emplace(key, value); + Py_INCREF(type); + PythonTypesMap()->emplace(key, type); Py_RETURN_NONE; } @@ -215,7 +196,7 @@ class CachedTypeCheck { // Returns 0 otherwise. // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.) int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) { - PyObject* type_obj = GetRegisteredPyObject(type_name); + PyObject* type_obj = GetRegisteredType(type_name); if (TF_PREDICT_FALSE(type_obj == nullptr)) { PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat( @@ -532,8 +513,7 @@ class AttrsValueIterator : public ValueIterator { }; bool IsSparseTensorValueType(PyObject* o) { - PyObject* sparse_tensor_value_type = - GetRegisteredPyObject("SparseTensorValue"); + PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue"); if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { return false; } diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index fc0b864416e..23438b43c53 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -19,8 +19,6 @@ limitations under the License. #include -#include - namespace tensorflow { namespace swig { @@ -272,20 +270,11 @@ PyObject* FlattenForData(PyObject* nested); PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, bool check_types); -// Registers a Python object so it can be looked up from c++. The set of -// valid names, and the expected values for those names, are listed in -// the documentation for `RegisteredPyObjects`. Returns PyNone. -PyObject* RegisterPyObject(PyObject* name, PyObject* value); - -// Variant of RegisterPyObject that requires the object's value to be a type. +// RegisterType is used to pass PyTypeObject (which is defined in python) for an +// arbitrary identifier `type_name` into C++. PyObject* RegisterType(PyObject* type_name, PyObject* type); } // namespace swig - -// Returns a borrowed reference to an object that was registered with -// RegisterPyObject. (Do not call PY_DECREF on the result). -PyObject* GetRegisteredPyObject(const std::string& name); - } // namespace tensorflow #endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_ diff --git a/tensorflow/python/util/util_wrapper.cc b/tensorflow/python/util/util_wrapper.cc index 63c70d785cc..dd74306413c 100644 --- a/tensorflow/python/util/util_wrapper.cc +++ b/tensorflow/python/util/util_wrapper.cc @@ -30,10 +30,6 @@ PYBIND11_MODULE(_pywrap_utils, m) { return tensorflow::PyoOrThrow( tensorflow::swig::RegisterType(type_name.ptr(), type.ptr())); }); - m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) { - return tensorflow::PyoOrThrow( - tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr())); - }); m.def( "IsTensor", [](const py::handle& o) {