Add helper to test whether an object implements the tf.DispatchableType protocol.

PiperOrigin-RevId: 324016472
Change-Id: I760151498d9ddb116b296031d835f78f71ca4e34
This commit is contained in:
Edward Loper 2020-07-30 09:19:50 -07:00 committed by TensorFlower Gardener
parent 1cc134c8bf
commit 9679d65211
2 changed files with 20 additions and 0 deletions
tensorflow/python/util

View File

@ -361,6 +361,16 @@ int IsSequenceHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
// Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
// Returns 0 otherwise.
int IsDispatchableHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
return PyObject_HasAttrString(
reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
});
return check_cache->CachedLookup(o);
}
// ValueIterator interface
class ValueIterator {
public:
@ -917,6 +927,7 @@ bool IsResourceVariable(PyObject* o) {
}
bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
bool IsTuple(PyObject* o) {
tensorflow::Safe_PyObjectPtr wrapped;

View File

@ -115,6 +115,15 @@ bool IsTuple(PyObject* o);
// True if the sequence subclasses mapping.
bool IsMappingView(PyObject* o);
// Returns a true if its input has a `__tf_dispatch__` attribute.
//
// Args:
// o: the input to be checked.
//
// Returns:
// True if `o` has a `__tf_dispatch__` attribute.
bool IsDispatchable(PyObject* o);
// A version of PyMapping_Keys that works in C++11
//
// Args: