Dispatch handler for Python APIs. For background, see the RFC for [TensorFlow Extension Types](https://github.com/tensorflow/community/pull/269).

PiperOrigin-RevId: 338353976
Change-Id: I7f155854973fb496c5c0f43eb24a46aa7418c8d6
Edward Loper 2020-10-21 15:34:56 -07:00 committed by TensorFlower Gardener
parent 0ed710fb76
commit 6f980e4a05
6 changed files with 706 additions and 0 deletions

tensorflow/python/BUILD

@@ -1382,6 +1382,7 @@ py_library(
":_pywrap_kernel_registry",
":_pywrap_py_exception_registry",
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
":_pywrap_python_api_dispatcher",
":_pywrap_python_op_gen",
":_pywrap_quantize_training",
":_pywrap_stacktrace_handler",
@@ -1752,6 +1753,45 @@ tf_py_test(
tfrt_enabled = True,
)
cc_library(
name = "python_api_dispatcher",
srcs = ["framework/python_api_dispatcher.cc"],
hdrs = ["framework/python_api_dispatcher.h"],
deps = [
":cpp_python_util",
":safe_pyobject_ptr",
"//tensorflow/core/platform:logging",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
],
)
# Note: this target is only used by python_api_dispatcher_test.
tf_python_pybind_extension(
name = "_pywrap_python_api_dispatcher",
# testonly = True,
srcs = ["framework/python_api_dispatcher_wrapper.cc"],
hdrs = ["framework/python_api_dispatcher.h"],
module_name = "_pywrap_python_api_dispatcher",
deps = [
":safe_pyobject_ptr_required_hdrs",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@pybind11",
],
)
tf_py_test(
name = "python_api_dispatcher_test",
srcs = ["framework/python_api_dispatcher_test.py"],
python_version = "PY3",
tags = ["no_pip"],
deps = [
":_pywrap_python_api_dispatcher",
":client_testlib",
],
)
py_library(
name = "framework_ops", # "ops" is already the name of a deprecated target
srcs = ["framework/ops.py"],
@@ -6049,6 +6089,7 @@ pywrap_tensorflow_macro(
":pybind11_lib",
":pybind11_status",
":pybind11_proto",
":python_api_dispatcher",
":python_op_gen",
":safe_pyobject_ptr",
":tf_session_helper",
@@ -6121,6 +6162,7 @@ filegroup(
":numpy_lib", # checkpoint_reader
":py_exception_registry", # py_exception_registry
":py_func_lib", # py_func
":python_api_dispatcher", # python_api_dispatcher
":python_op_gen", # python_op_gen
":safe_ptr", # checkpoint_reader
"//tensorflow/c:checkpoint_reader", # checkpoint_reader

tensorflow/python/framework/python_api_dispatcher.cc

@@ -0,0 +1,220 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/python_api_dispatcher.h"
#include <set>
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
namespace tensorflow {
using ParamInfo = PythonAPIDispatcher::ParamInfo;
// List of python types to check for dispatch. In most cases, this vector
// will have size zero or one; and sizes greater than 3 should be rare.
using TypeList = absl::InlinedVector<PyTypeObject*, 3>;
namespace {
// Returns the __tf_dispatch__ attribute of `obj`.
Safe_PyObjectPtr GetAttr_TFDispatch(PyObject* obj) {
#if PY_MAJOR_VERSION < 3
// Python 2.x:
static PyObject* attr = PyString_InternFromString("__tf_dispatch__");
#else
// Python 3.x:
static PyObject* attr = PyUnicode_InternFromString("__tf_dispatch__");
#endif
return Safe_PyObjectPtr(PyObject_GetAttr(obj, attr));
}
// Searches `params` for dispatchable types, and returns a vector of borrowed
// references to those types. Removes consecutive duplicates (i.e., if a
// dispatchable parameter has the same type as the previously encountered
// dispatchable parameter, then its type is not added again), so the result
// will usually have a length of zero or one; but in the general case, it may be
// longer, and may contain (nonconsecutive) duplicates.
//
// Assumes that `params` is a tuple, and that all parameter indices in
// `dispatch_params` and `dispatch_list_params` are valid.
TypeList FindDispatchTypes(PyObject* params,
const std::vector<ParamInfo>& dispatchable_params) {
TypeList dispatch_types;
for (const auto& param : dispatchable_params) {
DCHECK_GE(param.index, 0);
DCHECK_LT(param.index, PyTuple_GET_SIZE(params));
PyObject* value = PyTuple_GET_ITEM(params, param.index);
if (param.is_list) {
DCHECK(PyList_Check(value));
Py_ssize_t num_items = PyList_Size(value);
for (Py_ssize_t i = 0; i < num_items; ++i) {
PyObject* item = PyList_GET_ITEM(value, i);
// TODO(b/164980194) Consider changing IsDispatchable to not use a
// cache. This may impact efficiency (needs to be measured), but would
// allow us to support monkey-patching classes to be dispatchable.
if (swig::IsDispatchable(item)) {
if (dispatch_types.empty() ||
item->ob_type != dispatch_types.back()) {
dispatch_types.push_back(item->ob_type);
}
}
}
} else {
if (swig::IsDispatchable(value)) {
if (dispatch_types.empty() || value->ob_type != dispatch_types.back()) {
dispatch_types.push_back(value->ob_type);
}
}
}
}
return dispatch_types;
}
// Removes duplicates from `dispatch_types`, and moves any subtypes to
// before their supertypes. Note: this method is only called when
// `dispatch_types.size() > 1`.
void SortDispatchTypes(TypeList& dispatch_types) {
// Remove duplicates. Note: this is O(n^2) in the number of dispatchable
// types, but we expect this number to be very small in almost every case
// (usually zero, sometimes one, and rarely larger than two).
for (int i = 0; i < dispatch_types.size() - 1; ++i) {
if (dispatch_types[i] == nullptr) continue;
for (int j = i + 1; j < dispatch_types.size(); ++j) {
if (dispatch_types[i] == dispatch_types[j]) {
dispatch_types[j] = nullptr; // mark duplicate
}
}
}
dispatch_types.erase(
std::remove_if(dispatch_types.begin(), dispatch_types.end(),
[](PyTypeObject* t) { return t == nullptr; }),
dispatch_types.end());
// Move subclasses before superclasses. As above, this is O(n^2), but we
// expect n to be small.
TypeList sorted;
TypeList subtypes;
for (int i = 0; i < dispatch_types.size(); ++i) {
if (dispatch_types[i] == nullptr) continue;
subtypes.clear();
for (int j = i + 1; j < dispatch_types.size(); ++j) {
if (dispatch_types[j] == nullptr) continue;
if (PyType_IsSubtype(dispatch_types[j], dispatch_types[i])) {
subtypes.push_back(dispatch_types[j]);
dispatch_types[j] = nullptr; // mark as already added.
}
}
if (!subtypes.empty()) {
std::sort(subtypes.begin(), subtypes.end(), PyType_IsSubtype);
sorted.insert(sorted.end(), subtypes.begin(), subtypes.end());
}
sorted.push_back(dispatch_types[i]);
}
DCHECK_EQ(dispatch_types.size(), sorted.size());
dispatch_types.swap(sorted);
}
} // namespace
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
PyObject* api_func, int num_params,
bool right_to_left)
: api_name_(PyUnicode_FromStringAndSize(api_name.c_str(), api_name.size())),
api_func_(api_func),
num_params_(num_params),
right_to_left_(right_to_left) {
Py_INCREF(api_func);
}
bool PythonAPIDispatcher::Initialize(
std::vector<ParamInfo> dispatchable_params) {
dispatchable_params_.swap(dispatchable_params);
std::sort(dispatchable_params_.begin(), dispatchable_params_.end(),
[](const ParamInfo& a, const ParamInfo& b) -> bool {
return a.index < b.index;
});
if (right_to_left_) {
std::reverse(dispatchable_params_.begin(), dispatchable_params_.end());
}
for (const auto& p : dispatchable_params_) {
if (p.index < 0 || p.index >= num_params_) {
PyErr_SetString(
PyExc_ValueError,
absl::StrCat("PythonAPIDispatcher: dispatchable parameter index out ",
"of range: ", p.index, " not in [0, ", num_params_, ")")
.c_str());
return false;
}
}
return true;
}
PyObject* PythonAPIDispatcher::Dispatch(PyObject* params) const {
DCHECK(PyTuple_Check(params));
// TODO(b/164980194) Consider removing this check, if the caller is also
// checking/guaranteeing it (once dispatch has been integrated w/ the Python
// API handlers).
if (num_params_ != PyTuple_Size(params)) {
#if PY_MAJOR_VERSION < 3
// Python 2.x:
Safe_PyObjectPtr api_name_str(PyUnicode_AsUTF8String(api_name_.get()));
if (!api_name_str) return nullptr;
const char* api_name = PyString_AsString(api_name_str.get());
#else
// Python 3.x:
const char* api_name = PyUnicode_AsUTF8AndSize(api_name_.get(), nullptr);
#endif
PyErr_SetString(
PyExc_TypeError,
absl::StrCat(api_name ? api_name : "unknown PythonAPIDispatcher",
" expected ", num_params_, " parameters, but got ",
PyTuple_Size(params))
.c_str());
return nullptr;
}
TypeList dispatch_types = FindDispatchTypes(params, dispatchable_params_);
if (dispatch_types.empty()) {
return Py_NotImplemented;
}
if (dispatch_types.size() > 1) {
SortDispatchTypes(dispatch_types);
}
for (PyTypeObject* dispatch_type : dispatch_types) {
Safe_PyObjectPtr dispatcher =
GetAttr_TFDispatch(reinterpret_cast<PyObject*>(dispatch_type));
if (!dispatcher) return nullptr;
PyObject* result = PyObject_CallFunctionObjArgs(
dispatcher.get(), api_name_.get(), api_func_.get(), params, nullptr);
if (result != Py_NotImplemented) {
return result;
}
}
return Py_NotImplemented;
}
} // namespace tensorflow
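
The ordering that FindDispatchTypes and SortDispatchTypes establish above can be restated as a short Python sketch: collect candidate types, drop duplicates keeping the first occurrence, then hoist subclasses in front of their base classes. This is an illustrative reformulation only, not code that is part of this change; the helper names below are invented for the sketch.

import functools

def _subtype_first(a, b):
  # Mirrors the std::sort(..., PyType_IsSubtype) call above: `a` sorts
  # before `b` when `a` is a subclass of `b`.
  if a is not b and issubclass(a, b):
    return -1
  if a is not b and issubclass(b, a):
    return 1
  return 0

def sort_dispatch_types(dispatch_types):
  """Sketch of SortDispatchTypes: dedupe, then move subtypes before supertypes."""
  # Remove duplicates, keeping the first occurrence of each type.
  types = [t for i, t in enumerate(dispatch_types) if t not in dispatch_types[:i]]
  ordered = []
  consumed = set()
  for i, t in enumerate(types):
    if i in consumed:
      continue
    # Collect later entries that are subclasses of `t` and place them
    # (most-derived first) just before `t`.
    subtypes = []
    for j in range(i + 1, len(types)):
      if j not in consumed and issubclass(types[j], t):
        subtypes.append(types[j])
        consumed.add(j)
    subtypes.sort(key=functools.cmp_to_key(_subtype_first))
    ordered.extend(subtypes)
    ordered.append(t)
  return ordered

# Example: bool is a subclass of int, so it is hoisted in front of it.
# sort_dispatch_types([int, float, bool]) == [bool, int, float]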

tensorflow/python/framework/python_api_dispatcher.h

@@ -0,0 +1,131 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
#include <Python.h>
#include <string>
#include <vector>
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
namespace tensorflow {
// Dispatch handler for Python APIs.
//
// A separate PythonAPIDispatcher object is created for each Python API, and
// keeps track of which parameters should be checked for dispatch.
//
// When PythonAPIDispatcher::Dispatch() is called with a tuple of
// canonicalized parameters, it checks the indicated parameters' values for
// `__tf_dispatch__` methods. If found, then this method is called with the
// following arguments: `__tf_dispatch__(api_name, api_func, canon_args)`,
// where:
//
// * `api_name` is the fully-qualified name of the python API (e.g.,
// `"tf.math.sum"`).
// * `api_func` is the function that implements the API for `Tensor` inputs.
// * `canon_args` is the canonicalized argument list.
//
class PythonAPIDispatcher {
public:
// Information about an API parameter that supports dispatch. `index` is the
// parameter's index in the canonicalized parameter list, and `is_list` is
// true if the parameter expects a list of values (e.g. the `values` parameter
// to `tf.concat`).
struct ParamInfo {
int index;
bool is_list;
};
// Constructs a PythonAPIDispatcher.
//
// Args:
// api_name: The fully qualified name of the API handled by this dispatcher.
// api_func: The python function for which implements the API for `Tensor`
// inputs.
// num_params: The number of canonical parameters that the API expects.
// right_to_left: If true, then the normal precedence rules (in which
// dispatchers are tried from left-to-right) are changed to try
// dispatchers from right-to-left instead. This is used for operations
// such as `__radd__`, where the normal parameter order is reversed.
PythonAPIDispatcher(const std::string& api_name, PyObject* api_func,
int num_params, bool right_to_left = false);
// Initializes this PythonAPIDispatcher with information about which parameters
// support dispatch. Returns true on success, or sets a python exception and
// returns false on error.
bool Initialize(std::vector<ParamInfo> dispatchable_params);
// Checks if any of the dispatchable parameters have a `__tf_dispatch__`
// method, and if so, calls them. In particular, this method:
//
// 1. Constructs an ordered list of dispatchable types.
//
// * Checks each argument that supports dispatch to see if its value(s) have
// a `__tf_dispatch__` method.
// * Arguments are checked left-to-right unless `right_to_left` was set to
// True in the constructor. *Within* a list-valued parameter, elements
// are always checked left-to-right (even if `right_to_left` is True).
// * Duplicate types are removed (only the first occurrence of each type is
// kept).
// * If any type `T_sub` is a subtype of another type `T_super`, but occurs
// after `T_super` in the list of dispatchable types, then it is moved to
// just before `T_super`.
//
// 2. Tries calling each of the dispatchable types' `__tf_dispatch__` methods.
//
// * Dispatch methods are called with the following arguments:
// `__tf_dispatch__(api_name, api_func, canon_args)`
// * Dispatch methods are tried in the order described above.
// * If a dispatch method returns a value, then `Dispatch()` returns a
// new reference to that value.
// * If a dispatch method raises an exception, then `Dispatch()` returns
// null (i.e., propagates the exception).
// * If a dispatch method returns `NotImplemented`, then the dispatcher
// moves on to the next type.
//
// 3. If no dispatchers were found, or all dispatchers returned
// `NotImplemented`, then the dispatcher returns a *borrowed* reference
// to `Py_NotImplemented`.
//
// Args:
// params: A `PyTuple` containing the canonicalized parameters to the API.
// All `POSITIONAL_OR_KEYWORD` arguments must be converted to positional
// arguments (`KEYWORD_ONLY` arguments are not currently supported). Any
// dispatchable parameter with `is_list=True` must have been converted to
// `PyList`.
//
// Returns:
// * If a `__tf_dispatch__` handler successfully handled the API:
// Returns a *new* reference to the handler's return value.
// * If no handler was found, or all handlers returned NotImplemented:
// Returns a *borrowed* reference to `Py_NotImplemented`.
// * On error: Sets an exception and returns `nullptr`.
PyObject* Dispatch(PyObject* params) const;
private:
Safe_PyObjectPtr api_name_;
Safe_PyObjectPtr api_func_;
int num_params_;
std::vector<ParamInfo> dispatchable_params_;
bool right_to_left_;
};
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
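
As a concrete illustration of the calling convention documented above: a class opts into dispatch by defining a `__tf_dispatch__` classmethod that receives the API name, the Tensor implementation, and the canonicalized argument tuple, and returns either a result or NotImplemented. The MaskedValue class and its validity-combining rule below are hypothetical and exist only for this sketch; the Trace and WeightedTensor classes in the test file that follows are the versions actually exercised by this change.

class MaskedValue(object):
  """Hypothetical dispatchable type: a value plus a validity flag."""

  def __init__(self, value, valid=True):
    self.value = value
    self.valid = valid

  @classmethod
  def __tf_dispatch__(cls, api_name, api_func, canon_args):
    # Decline APIs this type does not handle; the dispatcher then moves on
    # to the next dispatchable type, or returns NotImplemented.
    if not api_name.startswith("tf.math."):
      return NotImplemented
    # Unwrap MaskedValue arguments and delegate the computation to the
    # regular Tensor implementation.
    values = [a.value if isinstance(a, cls) else a for a in canon_args]
    valid = all(a.valid for a in canon_args if isinstance(a, cls))
    return cls(api_func(*values), valid)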

tensorflow/python/framework/python_api_dispatcher_test.py

@@ -0,0 +1,244 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.python_api_dispatcher."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import _pywrap_python_api_dispatcher
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class Trace(object):
"""A dispatchable type that builds traces of ops it's called with."""
log = []
def __init__(self, api_name, *args):
self.api_name = api_name
self.args = args
@classmethod
def __tf_dispatch__(cls, api_name, api_func, args):
Trace.log.append("__tf_dispatch__%s" % ((cls.__name__, api_name),))
if "disabled" in str(args) or api_name == "disabled":
return NotImplemented
del api_func # not used
return cls(api_name, *args)
def __repr__(self):
return "%s%s" % (type(self).__name__, (self.api_name,) + self.args)
def __eq__(self, other):
return (type(self) is type(other) and self.api_name == other.api_name and
self.args == other.args)
class Trace2(Trace):
pass
class Trace2B(Trace2):
pass
class Trace3(Trace):
pass
class Trace4(Trace):
pass
class WeightedTensor(object):
def __init__(self, tensor, weight):
self.tensor = ops.convert_to_tensor(tensor)
self.weight = weight # Python float
@classmethod
def __tf_dispatch__(cls, api_name, api_func, args):
del api_name # unused
weights = [arg.weight for arg in args if isinstance(arg, WeightedTensor)]
tensors = [
arg.tensor if isinstance(arg, WeightedTensor) else arg for arg in args
]
tensor_result = api_func(*tensors)
avg_weight = sum(weights) / len(weights)
return cls(tensor_result, avg_weight)
@test_util.run_all_in_graph_and_eager_modes
class PythonAPIDispatcherTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
def testNoDispatchableTypes(self):
add_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
self.assertEqual(add_dispatcher.Dispatch(1, 2), NotImplemented)
concat_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"tf.concat", array_ops.concat, 2, [1], [0], False)
self.assertEqual(concat_dispatcher.Dispatch([1], 0), NotImplemented)
def testSimpleDispatchWithTrace(self):
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
x = 5
y = Trace("constant", "y")
z = Trace("constant", "z")
Trace.log.clear()
self.assertEqual(dispatcher.Dispatch(x, y), Trace("tf.math.add", x, y))
self.assertEqual(dispatcher.Dispatch(y, x), Trace("tf.math.add", y, x))
self.assertEqual(dispatcher.Dispatch(y, z), Trace("tf.math.add", y, z))
self.assertEqual(Trace.log, [
"__tf_dispatch__('Trace', 'tf.math.add')",
"__tf_dispatch__('Trace', 'tf.math.add')",
"__tf_dispatch__('Trace', 'tf.math.add')"
])
def testDispatcherReturnsNotImplemented(self):
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
x = 5
y = Trace("constant", "disabled")
z = Trace("constant", "z")
self.assertEqual(dispatcher.Dispatch(x, y), NotImplemented)
self.assertEqual(dispatcher.Dispatch(y, x), NotImplemented)
self.assertEqual(dispatcher.Dispatch(y, z), NotImplemented)
self.assertEqual(dispatcher.Dispatch(z, z), Trace("tf.math.add", z, z))
def testSimpleDispatchWithWeightedTensor(self):
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
x = 5
y = WeightedTensor([1, 2, 3], 0.6)
z = WeightedTensor([10, 20, 30], 0.2)
x_plus_y = dispatcher.Dispatch(x, y)
y_plus_x = dispatcher.Dispatch(y, x)
y_plus_z = dispatcher.Dispatch(y, z)
self.assertAllEqual(x_plus_y.tensor, [6, 7, 8])
self.assertAllEqual(y_plus_x.tensor, [6, 7, 8])
self.assertAllEqual(y_plus_z.tensor, [11, 22, 33])
self.assertEqual(x_plus_y.weight, 0.6)
self.assertEqual(y_plus_x.weight, 0.6)
self.assertEqual(y_plus_z.weight, 0.4)
def testDispatchPrecedence(self):
# We use an API for which dispatch is disabled, so all dispatchers get
# called (since this test checks the order of the dispatcher list).
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"disabled", None, 5, [0, 1, 4], [2, 3], False)
t = Trace("constant", "t")
t2_1 = Trace2("constant", "t2_1")
t2_2 = Trace2("constant", "t2_2")
t2b = Trace2B("constant", "t2b")
t3 = Trace3("constant", "t3")
t4 = Trace4("constant", "t4")
# Three dispatchable types, none of which is a subclass of the other:
# * precedence is left-to-right.
# * duplicates are removed.
Trace.log.clear()
result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4)
self.assertEqual(result, NotImplemented)
self.assertEqual(Trace.log, [
"__tf_dispatch__('Trace2', 'disabled')",
"__tf_dispatch__('Trace3', 'disabled')",
"__tf_dispatch__('Trace4', 'disabled')"
])
# Subtypes are moved before their base types.
Trace.log.clear()
result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b)
self.assertEqual(result, NotImplemented)
self.assertEqual(Trace.log, [
"__tf_dispatch__('Trace2B', 'disabled')",
"__tf_dispatch__('Trace2', 'disabled')",
"__tf_dispatch__('Trace3', 'disabled')",
"__tf_dispatch__('Trace4', 'disabled')",
"__tf_dispatch__('Trace', 'disabled')"
])
def testDispatchPrecedenceRightToLeft(self):
# We use an API for which dispatch is disabled, so all dispatchers get
# called (since this test checks the order of the dispatcher list).
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
"disabled", None, 5, [4, 0, 1], [2, 3], True)
t = Trace("constant", "t")
t2_1 = Trace2("constant", "t2_1")
t2_2 = Trace2("constant", "t2_2")
t2b = Trace2B("constant", "t2b")
t3 = Trace3("constant", "t3")
t4 = Trace4("constant", "t4")
# Three dispatchable types, none of which is a subclass of the other:
# * precedence is right_to_left (since we set right_to_left=True in the
# PythonAPIDispatcher constructor). (Note: arguments are scanned
# right-to-left, but the elements of list arguments are still scanned
# left-to-right.)
# * duplicates are removed.
Trace.log.clear()
result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4)
self.assertEqual(result, NotImplemented)
self.assertEqual(Trace.log, [
"__tf_dispatch__('Trace4', 'disabled')",
"__tf_dispatch__('Trace2', 'disabled')",
"__tf_dispatch__('Trace3', 'disabled')"
])
# Subtypes are moved before their base types. (Note: moving subtypes occurs
# *after* we swap the order to be right-to-left; so the dispatch order here
# is not what we'd get by just reversing the final dispatch order if
# right_to_left were false.)
Trace.log.clear()
result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b)
self.assertEqual(result, NotImplemented)
self.assertEqual(Trace.log, [
"__tf_dispatch__('Trace2B', 'disabled')",
"__tf_dispatch__('Trace2', 'disabled')",
"__tf_dispatch__('Trace3', 'disabled')",
"__tf_dispatch__('Trace4', 'disabled')",
"__tf_dispatch__('Trace', 'disabled')"
])
def testDispatchParamOutOfRange(self):
with self.assertRaisesRegex(ValueError, "index out of range"):
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
[0, 1, 5], [2, 3], True)
with self.assertRaisesRegex(ValueError, "index out of range"):
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
[0, -3], [2, 3], True)
with self.assertRaisesRegex(ValueError, "index out of range"):
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
[0, 1], [10, 3], True)
if __name__ == "__main__":
googletest.main()

tensorflow/python/framework/python_api_dispatcher_wrapper.cc

@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Note: This library is only used by python_api_dispatcher_test. It is
// not meant to be used in other circumstances.
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/python/framework/python_api_dispatcher.h"
namespace py = pybind11;
namespace {
tensorflow::PythonAPIDispatcher MakePythonAPIDispatcher(
const std::string& api_name, py::handle api_func, int num_params,
const std::vector<int>& dispatch_params,
const std::vector<int>& dispatch_list_params, bool right_to_left) {
std::vector<tensorflow::PythonAPIDispatcher::ParamInfo> dispatchable_params;
dispatchable_params.reserve(dispatch_params.size() +
dispatch_list_params.size());
for (int p : dispatch_params) {
dispatchable_params.push_back({p, false});
}
for (int p : dispatch_list_params) {
dispatchable_params.push_back({p, true});
}
auto dispatcher = tensorflow::PythonAPIDispatcher(api_name, api_func.ptr(),
num_params, right_to_left);
if (!dispatcher.Initialize(dispatchable_params)) {
throw py::error_already_set();
}
return dispatcher;
}
py::handle Dispatch(tensorflow::PythonAPIDispatcher* self, py::args args) {
auto result = self->Dispatch(args.ptr());
if (result == nullptr) {
throw py::error_already_set();
} else if (result == Py_NotImplemented) {
Py_INCREF(result);
return result;
} else {
return result;
}
}
} // namespace
PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) {
py::class_<tensorflow::PythonAPIDispatcher>(m, "PythonAPIDispatcher")
.def(py::init(&MakePythonAPIDispatcher))
.def("Dispatch", Dispatch);
}
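
For reference, the positional arguments of the pybind constructor above are (api_name, api_func, num_params, dispatch_params, dispatch_list_params, right_to_left); the two index lists become ParamInfo entries with is_list=False and is_list=True respectively. A usage sketch, mirroring the concat dispatcher built in the test above (note that this binding is test-only, per the BUILD comment earlier in this change):

from tensorflow.python import _pywrap_python_api_dispatcher
from tensorflow.python.ops import array_ops

# Two canonical parameters: parameter 0 is list-valued and dispatchable
# (e.g. the `values` argument of tf.concat); parameter 1 is a plain
# dispatchable parameter. Dispatchers are tried left-to-right.
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
    "tf.concat",        # api_name
    array_ops.concat,   # api_func: the Tensor implementation
    2,                  # num_params
    [1],                # dispatch_params      -> ParamInfo{index=1, is_list=False}
    [0],                # dispatch_list_params -> ParamInfo{index=0, is_list=True}
    False)              # right_to_left

# With no dispatchable arguments, Dispatch returns NotImplemented and the
# caller falls back to the regular Tensor implementation.
assert dispatcher.Dispatch([1, 2, 3], 0) is NotImplemented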

tensorflow/tools/def_file_filter/symbols_pybind.txt

@@ -398,3 +398,5 @@ stream_executor::port::internal_statusor::Helper::Crash
[tensor_handle] # tfe
tensorflow::TensorHandle::Tensor
[python_api_dispatcher] # python_api_dispatcher
tensorflow::PythonAPIDispatcher