Dispatch handler for Python APIs. For background, see the RFC for [TensorFlow Extension Types](https://github.com/tensorflow/community/pull/269).
PiperOrigin-RevId: 338353976 Change-Id: I7f155854973fb496c5c0f43eb24a46aa7418c8d6
This commit is contained in:
parent
0ed710fb76
commit
6f980e4a05
tensorflow
@ -1382,6 +1382,7 @@ py_library(
|
||||
":_pywrap_kernel_registry",
|
||||
":_pywrap_py_exception_registry",
|
||||
":_pywrap_py_func", # TODO(b/142001480): remove once the bug is fixed.
|
||||
":_pywrap_python_api_dispatcher",
|
||||
":_pywrap_python_op_gen",
|
||||
":_pywrap_quantize_training",
|
||||
":_pywrap_stacktrace_handler",
|
||||
@ -1752,6 +1753,45 @@ tf_py_test(
|
||||
tfrt_enabled = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_api_dispatcher",
|
||||
srcs = ["framework/python_api_dispatcher.cc"],
|
||||
hdrs = ["framework/python_api_dispatcher.h"],
|
||||
deps = [
|
||||
":cpp_python_util",
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
# Note: this target is only used by python_api_dispatcher_test.
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_python_api_dispatcher",
|
||||
# testonly = True,
|
||||
srcs = ["framework/python_api_dispatcher_wrapper.cc"],
|
||||
hdrs = ["framework/python_api_dispatcher.h"],
|
||||
module_name = "_pywrap_python_api_dispatcher",
|
||||
deps = [
|
||||
":safe_pyobject_ptr_required_hdrs",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "python_api_dispatcher_test",
|
||||
srcs = ["framework/python_api_dispatcher_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":_pywrap_python_api_dispatcher",
|
||||
":client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "framework_ops", # "ops" is already the name of a deprecated target
|
||||
srcs = ["framework/ops.py"],
|
||||
@ -6049,6 +6089,7 @@ pywrap_tensorflow_macro(
|
||||
":pybind11_lib",
|
||||
":pybind11_status",
|
||||
":pybind11_proto",
|
||||
":python_api_dispatcher",
|
||||
":python_op_gen",
|
||||
":safe_pyobject_ptr",
|
||||
":tf_session_helper",
|
||||
@ -6121,6 +6162,7 @@ filegroup(
|
||||
":numpy_lib", # checkpoint_reader
|
||||
":py_exception_registry", # py_exception_registry
|
||||
":py_func_lib", # py_func
|
||||
":python_api_dispatcher", # python_api_dispatcher
|
||||
":python_op_gen", # python_op_gen
|
||||
":safe_ptr", # checkpoint_reader
|
||||
"//tensorflow/c:checkpoint_reader", # checkpoint_reader
|
||||
|
220
tensorflow/python/framework/python_api_dispatcher.cc
Normal file
220
tensorflow/python/framework/python_api_dispatcher.cc
Normal file
@ -0,0 +1,220 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/framework/python_api_dispatcher.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
#include "tensorflow/python/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using ParamInfo = PythonAPIDispatcher::ParamInfo;
|
||||
|
||||
// List of python types to check for dispatch. In most cases, this vector
|
||||
// will have size zero or one; and sizes greater than 3 should be rare.
|
||||
using TypeList = absl::InlinedVector<PyTypeObject*, 3>;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the __tf__dispatch__ attribute of `obj`.
|
||||
Safe_PyObjectPtr GetAttr_TFDispatch(PyObject* obj) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
static PyObject* attr = PyString_InternFromString("__tf_dispatch__");
|
||||
#else
|
||||
// Python 3.x:
|
||||
static PyObject* attr = PyUnicode_InternFromString("__tf_dispatch__");
|
||||
#endif
|
||||
return Safe_PyObjectPtr(PyObject_GetAttr(obj, attr));
|
||||
}
|
||||
|
||||
// Searches `params` for dispatchable types, and returns a vector of borrowed
|
||||
// references to those types. Removes consecutive duplicates (i.e., if a
|
||||
// dispatchable parameter has the same type as the previously encountered
|
||||
// dispatcahble parameter, then it's type is not added again), so the result
|
||||
// will usually have a length of zero or one; but in the general case, it may be
|
||||
// longer, and may contain (nonconsecutive) duplicates.
|
||||
//
|
||||
// Assumes that `params` is a tuple, and that all parameter indices in
|
||||
// `dispatch_params` and `dispatch_list_params` are valid.
|
||||
TypeList FindDispatchTypes(PyObject* params,
|
||||
const std::vector<ParamInfo>& dispatchable_params) {
|
||||
TypeList dispatch_types;
|
||||
for (const auto& param : dispatchable_params) {
|
||||
DCHECK_GE(param.index, 0);
|
||||
DCHECK_LT(param.index, PyTuple_GET_SIZE(params));
|
||||
PyObject* value = PyTuple_GET_ITEM(params, param.index);
|
||||
if (param.is_list) {
|
||||
DCHECK(PyList_Check(value));
|
||||
Py_ssize_t num_items = PyList_Size(value);
|
||||
for (Py_ssize_t i = 0; i < num_items; ++i) {
|
||||
PyObject* item = PyList_GET_ITEM(value, i);
|
||||
// TODO(b/164980194) Consider changing IsDispatchable to not use a
|
||||
// cache. This may impact efficiency (needs to be measured), but would
|
||||
// allow us to support monkey-patching classes to be dispatchable.
|
||||
if (swig::IsDispatchable(item)) {
|
||||
if (dispatch_types.empty() ||
|
||||
value->ob_type != dispatch_types.back()) {
|
||||
dispatch_types.push_back(item->ob_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (swig::IsDispatchable(value)) {
|
||||
if (dispatch_types.empty() || value->ob_type != dispatch_types.back()) {
|
||||
dispatch_types.push_back(value->ob_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dispatch_types;
|
||||
}
|
||||
|
||||
// Removes duplicates from `dispatch_types`, and moves any subtypes to
|
||||
// before their supertypes. Note: this method is only called when
|
||||
// `dispatch_types.size() > 1`.
|
||||
void SortDispatchTypes(TypeList& dispatch_types) {
|
||||
// Remove duplicates. Note: this is O(n^2) in the number of dispatchable
|
||||
// types, but we expect this number to be very small in almost every case
|
||||
// (usually zero, sometimes one, and rarely larger than two).
|
||||
for (int i = 0; i < dispatch_types.size() - 1; ++i) {
|
||||
if (dispatch_types[i] == nullptr) continue;
|
||||
for (int j = i + 1; j < dispatch_types.size(); ++j) {
|
||||
if (dispatch_types[i] == dispatch_types[j]) {
|
||||
dispatch_types[j] = nullptr; // mark duplicate
|
||||
}
|
||||
}
|
||||
}
|
||||
dispatch_types.erase(
|
||||
std::remove_if(dispatch_types.begin(), dispatch_types.end(),
|
||||
[](PyTypeObject* t) { return t == nullptr; }),
|
||||
dispatch_types.end());
|
||||
|
||||
// Move subclasses before superclasses. As above, this is O(n^2), but we
|
||||
// expect n to be small.
|
||||
TypeList sorted;
|
||||
TypeList subtypes;
|
||||
for (int i = 0; i < dispatch_types.size(); ++i) {
|
||||
if (dispatch_types[i] == nullptr) continue;
|
||||
subtypes.clear();
|
||||
for (int j = i + 1; j < dispatch_types.size(); ++j) {
|
||||
if (dispatch_types[j] == nullptr) continue;
|
||||
if (PyType_IsSubtype(dispatch_types[j], dispatch_types[i])) {
|
||||
subtypes.push_back(dispatch_types[j]);
|
||||
dispatch_types[j] = nullptr; // mark as already added.
|
||||
}
|
||||
}
|
||||
if (!subtypes.empty()) {
|
||||
std::sort(subtypes.begin(), subtypes.end(), PyType_IsSubtype);
|
||||
sorted.insert(sorted.end(), subtypes.begin(), subtypes.end());
|
||||
}
|
||||
sorted.push_back(dispatch_types[i]);
|
||||
}
|
||||
DCHECK_EQ(dispatch_types.size(), sorted.size());
|
||||
dispatch_types.swap(sorted);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
|
||||
PyObject* api_func, int num_params,
|
||||
bool right_to_left)
|
||||
: api_name_(PyUnicode_FromStringAndSize(api_name.c_str(), api_name.size())),
|
||||
api_func_(api_func),
|
||||
num_params_(num_params),
|
||||
right_to_left_(right_to_left) {
|
||||
Py_INCREF(api_func);
|
||||
}
|
||||
|
||||
bool PythonAPIDispatcher::Initialize(
|
||||
std::vector<ParamInfo> dispatchable_params) {
|
||||
dispatchable_params_.swap(dispatchable_params);
|
||||
std::sort(dispatchable_params_.begin(), dispatchable_params_.end(),
|
||||
[](const ParamInfo& a, const ParamInfo& b) -> bool {
|
||||
return a.index < b.index;
|
||||
});
|
||||
if (right_to_left_) {
|
||||
std::reverse(dispatchable_params_.begin(), dispatchable_params_.end());
|
||||
}
|
||||
|
||||
for (const auto& p : dispatchable_params_) {
|
||||
if (p.index < 0 || p.index >= num_params_) {
|
||||
PyErr_SetString(
|
||||
PyExc_ValueError,
|
||||
absl::StrCat("PythonAPIDispatcher: dispatchable parameter index out ",
|
||||
"of range: ", p.index, " not in [0, ", num_params_, ")")
|
||||
.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
PyObject* PythonAPIDispatcher::Dispatch(PyObject* params) const {
|
||||
DCHECK(PyTuple_Check(params));
|
||||
|
||||
// TODO(b/164980194) Consider removing this check, if the caller is also
|
||||
// checking/guaranteeing it (once dispatch has been integrated w/ the Python
|
||||
// API handlers).
|
||||
if (num_params_ != PyTuple_Size(params)) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
// Python 2.x:
|
||||
Safe_PyObjectPtr api_name_str(PyUnicode_AsUTF8String(api_name_.get()));
|
||||
if (!api_name_str) return nullptr;
|
||||
const char* api_name = PyString_AsString(api_name_str.get());
|
||||
#else
|
||||
// Python 3.x:
|
||||
const char* api_name = PyUnicode_AsUTF8AndSize(api_name_.get(), nullptr);
|
||||
#endif
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
absl::StrCat(api_name ? api_name : "unknown PythonAPIDispatcher",
|
||||
" expected ", num_params_, " parameters, but got ",
|
||||
PyTuple_Size(params))
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TypeList dispatch_types = FindDispatchTypes(params, dispatchable_params_);
|
||||
|
||||
if (dispatch_types.empty()) {
|
||||
return Py_NotImplemented;
|
||||
}
|
||||
|
||||
if (dispatch_types.size() > 1) {
|
||||
SortDispatchTypes(dispatch_types);
|
||||
}
|
||||
|
||||
for (PyTypeObject* dispatch_type : dispatch_types) {
|
||||
Safe_PyObjectPtr dispatcher =
|
||||
GetAttr_TFDispatch(reinterpret_cast<PyObject*>(dispatch_type));
|
||||
if (!dispatcher) return nullptr;
|
||||
PyObject* result = PyObject_CallFunctionObjArgs(
|
||||
dispatcher.get(), api_name_.get(), api_func_.get(), params, nullptr);
|
||||
if (result != Py_NotImplemented) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
return Py_NotImplemented;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
131
tensorflow/python/framework/python_api_dispatcher.h
Normal file
131
tensorflow/python/framework/python_api_dispatcher.h
Normal file
@ -0,0 +1,131 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
|
||||
#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Dispatch handler for Python APIs.
|
||||
//
|
||||
// A separate PythonAPIDispatcher object is created for each Python API, and
|
||||
// keeps track of which parameters should be checked for dispatch.
|
||||
//
|
||||
// When PythonAPIDispatcher::Dispatch() is called with a tuple of
|
||||
// canonicalized parameters, it checks the indicated parameters' values for
|
||||
// `__tf_dispatch__` methods. If found, then this method is called with the
|
||||
// following arguments: `__tf_dispatch__(api_name, api_func, canon_args)`,
|
||||
// where:
|
||||
//
|
||||
// * `api_name` is the fully-qualified name of the python API (e.g.,
|
||||
// `"tf.math.sum"`).
|
||||
// * `api_func` is the function that implements the APIs for `Tensor` inputs.
|
||||
// * `canon_args` is the canonicalized argument list.
|
||||
//
|
||||
class PythonAPIDispatcher {
|
||||
public:
|
||||
// Information about an API parameter that supports dispatch. `index` is the
|
||||
// parameter's index in the canonicalized parameter list, and `is_list` is
|
||||
// true if the parameter expects a list of values (e.g. the `values` parameter
|
||||
// to `tf.concat`).
|
||||
struct ParamInfo {
|
||||
int index;
|
||||
bool is_list;
|
||||
};
|
||||
|
||||
// Constructs a PythonAPIDispatcher.
|
||||
//
|
||||
// Args:
|
||||
// api_name: The fully qualified name of the API handled by this dispatcher.
|
||||
// api_func: The python function for which implements the API for `Tensor`
|
||||
// inputs.
|
||||
// num_params: The number of canonical parameters that the API expects.
|
||||
// right_to_left: If true, then the normal precedence rules (in which
|
||||
// dispatchers are tried from left-to-right) are changed to try
|
||||
// dispatchers from right-to-left instead. This is used for operations
|
||||
// such as `__radd__`, where the normal parameter order is reversed.
|
||||
PythonAPIDispatcher(const std::string& api_name, PyObject* api_func,
|
||||
int num_params, bool right_to_left = false);
|
||||
|
||||
// Initiliaze this PythonAPIDispatcher with information about which parameters
|
||||
// support dispatch. Returns true on success, or sets a python exception and
|
||||
// returns false on error.
|
||||
bool Initialize(std::vector<ParamInfo> dispatchable_params);
|
||||
|
||||
// Checks if any of the dispatchable parameters have a `__tf_dispatch__`
|
||||
// method, and if so, calls them. In particular, this method:
|
||||
//
|
||||
// 1. Constructs an ordered list of dispatchable types.
|
||||
//
|
||||
// * Checks each argument that support dispatch to see if its value(s) have
|
||||
// a `__tf_dispatch__` method.
|
||||
// * Arguments are checked left-to-right unless `right_to_left` was set to
|
||||
// True in the constructor. *Within* a list-valued parameter, elements
|
||||
// are always checked left-to-right (even if `right_to_left` is True).
|
||||
// * Duplicate types are removed (only the first occurrence of each type is
|
||||
// kept).
|
||||
// * If any type `T_sub` is a subtype of another type `T_super`, but occurs
|
||||
// after `T_super` in the list of dispatchable types, then it is moved to
|
||||
// just before `T_super`.
|
||||
//
|
||||
// 2. Tries calling each of the dispatchable types' `__tf_dispatch__` methods.
|
||||
//
|
||||
// * Dispatch methods are called with the following arguments:
|
||||
// `__tf_dispatch__(api_name, api_func, canon_args)`
|
||||
// * Dispatch methods are tried in the order described above.
|
||||
// * If a dispatch method returns a value, then `Dispatch()` returns a
|
||||
// new reference to that value.
|
||||
// * If a dispatch method raises an exception, then `Dispatch()` returns
|
||||
// null (i.e., propogates the exception).
|
||||
// * If a dispatch method returns `NotImplemented`, then the dispatcher
|
||||
// moves on to the next type.
|
||||
//
|
||||
// 3. If no dispatchers for found, or all dispatchers returned
|
||||
// `NotImplemented', then the dispatcher returns a *borrowed* reference
|
||||
// to `Py_NotImplemented`.
|
||||
//
|
||||
// Args:
|
||||
// params: A `PyTuple` containing the canonicalized parameters to the API.
|
||||
// All `POSITIONAL_OR_KEYWORD` arguments must be converted to positional
|
||||
// arguments (`KEYWORD_ONLY` arguments are not currently supported). Any
|
||||
// dispatchable parameter with `is_list=True` must have been converted to
|
||||
// `PyList`.
|
||||
//
|
||||
// Returns:
|
||||
// * If a `__tf_dispatch__` handler successfully handled the API:
|
||||
// Returns a *new* reference to the handler's return value.
|
||||
// * If no handler was found, or all handlers returned NotImplemented:
|
||||
// Returns a *borrowed* reference to `Py_NotImplemented`.
|
||||
// * On error: Sets an exception and returns `nullptr`.
|
||||
PyObject* Dispatch(PyObject* params) const;
|
||||
|
||||
private:
|
||||
Safe_PyObjectPtr api_name_;
|
||||
Safe_PyObjectPtr api_func_;
|
||||
int num_params_;
|
||||
std::vector<ParamInfo> dispatchable_params_;
|
||||
bool right_to_left_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
|
244
tensorflow/python/framework/python_api_dispatcher_test.py
Normal file
244
tensorflow/python/framework/python_api_dispatcher_test.py
Normal file
@ -0,0 +1,244 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework.python_api_dispatcher."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import _pywrap_python_api_dispatcher
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class Trace(object):
|
||||
"""A dispatchable type that builds traces of ops it's called with."""
|
||||
|
||||
log = []
|
||||
|
||||
def __init__(self, api_name, *args):
|
||||
self.api_name = api_name
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def __tf_dispatch__(cls, api_name, api_func, args):
|
||||
Trace.log.append("__tf_dispatch__%s" % ((cls.__name__, api_name),))
|
||||
if "disabled" in str(args) or api_name == "disabled":
|
||||
return NotImplemented
|
||||
del api_func # not used
|
||||
return cls(api_name, *args)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s%s" % (type(self).__name__, (self.api_name,) + self.args)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other) and self.api_name == other.api_name and
|
||||
self.args == other.args)
|
||||
|
||||
|
||||
class Trace2(Trace):
|
||||
pass
|
||||
|
||||
|
||||
class Trace2B(Trace2):
|
||||
pass
|
||||
|
||||
|
||||
class Trace3(Trace):
|
||||
pass
|
||||
|
||||
|
||||
class Trace4(Trace):
|
||||
pass
|
||||
|
||||
|
||||
class WeightedTensor(object):
|
||||
|
||||
def __init__(self, tensor, weight):
|
||||
self.tensor = ops.convert_to_tensor(tensor)
|
||||
self.weight = weight # Python float
|
||||
|
||||
@classmethod
|
||||
def __tf_dispatch__(cls, api_name, api_func, args):
|
||||
del api_name # unused
|
||||
weights = [arg.weight for arg in args if isinstance(arg, WeightedTensor)]
|
||||
tensors = [
|
||||
arg.tensor if isinstance(arg, WeightedTensor) else arg for arg in args
|
||||
]
|
||||
tensor_result = api_func(*tensors)
|
||||
avg_weight = sum(weights) / len(weights)
|
||||
return cls(tensor_result, avg_weight)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class PythonAPIDispatcherTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testNoDispatchableTypes(self):
|
||||
add_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
|
||||
self.assertEqual(add_dispatcher.Dispatch(1, 2), NotImplemented)
|
||||
|
||||
concat_dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"tf.concat", array_ops.concat, 2, [1], [0], False)
|
||||
self.assertEqual(concat_dispatcher.Dispatch([1], 0), NotImplemented)
|
||||
|
||||
def testSimpleDispatchWithTrace(self):
|
||||
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
|
||||
x = 5
|
||||
y = Trace("constant", "y")
|
||||
z = Trace("constant", "z")
|
||||
|
||||
Trace.log.clear()
|
||||
self.assertEqual(dispatcher.Dispatch(x, y), Trace("tf.math.add", x, y))
|
||||
self.assertEqual(dispatcher.Dispatch(y, x), Trace("tf.math.add", y, x))
|
||||
self.assertEqual(dispatcher.Dispatch(y, z), Trace("tf.math.add", y, z))
|
||||
self.assertEqual(Trace.log, [
|
||||
"__tf_dispatch__('Trace', 'tf.math.add')",
|
||||
"__tf_dispatch__('Trace', 'tf.math.add')",
|
||||
"__tf_dispatch__('Trace', 'tf.math.add')"
|
||||
])
|
||||
|
||||
def testDispatcherReturnsNotImplemented(self):
|
||||
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
|
||||
x = 5
|
||||
y = Trace("constant", "disabled")
|
||||
z = Trace("constant", "z")
|
||||
|
||||
self.assertEqual(dispatcher.Dispatch(x, y), NotImplemented)
|
||||
self.assertEqual(dispatcher.Dispatch(y, x), NotImplemented)
|
||||
self.assertEqual(dispatcher.Dispatch(y, z), NotImplemented)
|
||||
self.assertEqual(dispatcher.Dispatch(z, z), Trace("tf.math.add", z, z))
|
||||
|
||||
def testSimpleDispatchWithWeightedTensor(self):
|
||||
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"tf.math.add", math_ops.add, 2, [0, 1], [], False)
|
||||
x = 5
|
||||
y = WeightedTensor([1, 2, 3], 0.6)
|
||||
z = WeightedTensor([10, 20, 30], 0.2)
|
||||
|
||||
x_plus_y = dispatcher.Dispatch(x, y)
|
||||
y_plus_x = dispatcher.Dispatch(y, x)
|
||||
y_plus_z = dispatcher.Dispatch(y, z)
|
||||
|
||||
self.assertAllEqual(x_plus_y.tensor, [6, 7, 8])
|
||||
self.assertAllEqual(y_plus_x.tensor, [6, 7, 8])
|
||||
self.assertAllEqual(y_plus_z.tensor, [11, 22, 33])
|
||||
|
||||
self.assertEqual(x_plus_y.weight, 0.6)
|
||||
self.assertEqual(y_plus_x.weight, 0.6)
|
||||
self.assertEqual(y_plus_z.weight, 0.4)
|
||||
|
||||
def testDispatchPrecedence(self):
|
||||
# We use an API for which dispatch is disabled, so all dispatchers get
|
||||
# called (since this test checks the order of the dispatcher list).
|
||||
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"disabled", None, 5, [0, 1, 4], [2, 3], False)
|
||||
|
||||
t = Trace("constant", "t")
|
||||
t2_1 = Trace2("constant", "t2_1")
|
||||
t2_2 = Trace2("constant", "t2_2")
|
||||
t2b = Trace2B("constant", "t2b")
|
||||
t3 = Trace3("constant", "t3")
|
||||
t4 = Trace4("constant", "t4")
|
||||
|
||||
# Three dispatchable types, none of which is a subclass of the other:
|
||||
# * precedence is left-to-right.
|
||||
# * duplicates are removed.
|
||||
Trace.log.clear()
|
||||
result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4)
|
||||
self.assertEqual(result, NotImplemented)
|
||||
self.assertEqual(Trace.log, [
|
||||
"__tf_dispatch__('Trace2', 'disabled')",
|
||||
"__tf_dispatch__('Trace3', 'disabled')",
|
||||
"__tf_dispatch__('Trace4', 'disabled')"
|
||||
])
|
||||
|
||||
# Subtypes are moved before their base types.
|
||||
Trace.log.clear()
|
||||
result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b)
|
||||
self.assertEqual(result, NotImplemented)
|
||||
self.assertEqual(Trace.log, [
|
||||
"__tf_dispatch__('Trace2B', 'disabled')",
|
||||
"__tf_dispatch__('Trace2', 'disabled')",
|
||||
"__tf_dispatch__('Trace3', 'disabled')",
|
||||
"__tf_dispatch__('Trace4', 'disabled')",
|
||||
"__tf_dispatch__('Trace', 'disabled')"
|
||||
])
|
||||
|
||||
def testDispatchPrecedenceRightToLeft(self):
|
||||
# We use an API for which dispatch is disabled, so all dispatchers get
|
||||
# called (since this test checks the order of the dispatcher list).
|
||||
dispatcher = _pywrap_python_api_dispatcher.PythonAPIDispatcher(
|
||||
"disabled", None, 5, [4, 0, 1], [2, 3], True)
|
||||
|
||||
t = Trace("constant", "t")
|
||||
t2_1 = Trace2("constant", "t2_1")
|
||||
t2_2 = Trace2("constant", "t2_2")
|
||||
t2b = Trace2B("constant", "t2b")
|
||||
t3 = Trace3("constant", "t3")
|
||||
t4 = Trace4("constant", "t4")
|
||||
|
||||
# Three dispatchable types, none of which is a subclass of the other:
|
||||
# * precedence is right_to_left (since we set right_to_left=True in the
|
||||
# PtyonAPIDispatcher constructor). (Note: arguments are scanned
|
||||
# right-to-left, but the elements of list arguments are still scanned
|
||||
# left-to-right.)
|
||||
# * duplicates are removed.
|
||||
Trace.log.clear()
|
||||
result = dispatcher.Dispatch(t2_1, t3, [], [t2_2, t3], t4)
|
||||
self.assertEqual(result, NotImplemented)
|
||||
self.assertEqual(Trace.log, [
|
||||
"__tf_dispatch__('Trace4', 'disabled')",
|
||||
"__tf_dispatch__('Trace2', 'disabled')",
|
||||
"__tf_dispatch__('Trace3', 'disabled')"
|
||||
])
|
||||
|
||||
# Subtypes are moved before their base types. (Note: moving subtypes occurs
|
||||
# *after* we swap the order to be right-to-left; so the dispatch order here
|
||||
# is not what we'd get by just reversing the final dispatch order if
|
||||
# right_to_left were false.)
|
||||
Trace.log.clear()
|
||||
result = dispatcher.Dispatch(t2_1, t3, [t], [t2_2, t, t3, t4], t2b)
|
||||
self.assertEqual(result, NotImplemented)
|
||||
self.assertEqual(Trace.log, [
|
||||
"__tf_dispatch__('Trace2B', 'disabled')",
|
||||
"__tf_dispatch__('Trace2', 'disabled')",
|
||||
"__tf_dispatch__('Trace3', 'disabled')",
|
||||
"__tf_dispatch__('Trace4', 'disabled')",
|
||||
"__tf_dispatch__('Trace', 'disabled')"
|
||||
])
|
||||
|
||||
def testDispatchParamOutOfRange(self):
|
||||
with self.assertRaisesRegex(ValueError, "index out of range"):
|
||||
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
|
||||
[0, 1, 5], [2, 3], True)
|
||||
with self.assertRaisesRegex(ValueError, "index out of range"):
|
||||
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
|
||||
[0, -3], [2, 3], True)
|
||||
with self.assertRaisesRegex(ValueError, "index out of range"):
|
||||
_pywrap_python_api_dispatcher.PythonAPIDispatcher("some_api", None, 5,
|
||||
[0, 1], [10, 3], True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
67
tensorflow/python/framework/python_api_dispatcher_wrapper.cc
Normal file
67
tensorflow/python/framework/python_api_dispatcher_wrapper.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Note: This library is only used by python_api_dispatcher_test. It is
|
||||
// not meant to be used in other circumstances.
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/framework/python_api_dispatcher.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
tensorflow::PythonAPIDispatcher MakePythonAPIDispatcher(
|
||||
const std::string& api_name, py::handle api_func, int num_params,
|
||||
const std::vector<int>& dispatch_params,
|
||||
const std::vector<int>& dispatch_list_params, bool right_to_left) {
|
||||
std::vector<tensorflow::PythonAPIDispatcher::ParamInfo> dispatchable_params;
|
||||
dispatchable_params.reserve(dispatch_params.size() +
|
||||
dispatch_list_params.size());
|
||||
for (int p : dispatch_params) {
|
||||
dispatchable_params.push_back({p, false});
|
||||
}
|
||||
for (int p : dispatch_list_params) {
|
||||
dispatchable_params.push_back({p, true});
|
||||
}
|
||||
|
||||
auto dispatcher = tensorflow::PythonAPIDispatcher(api_name, api_func.ptr(),
|
||||
num_params, right_to_left);
|
||||
if (!dispatcher.Initialize(dispatchable_params)) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
return dispatcher;
|
||||
}
|
||||
|
||||
py::handle Dispatch(tensorflow::PythonAPIDispatcher* self, py::args args) {
|
||||
auto result = self->Dispatch(args.ptr());
|
||||
if (result == nullptr) {
|
||||
throw py::error_already_set();
|
||||
} else if (result == Py_NotImplemented) {
|
||||
Py_INCREF(result);
|
||||
return result;
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) {
|
||||
py::class_<tensorflow::PythonAPIDispatcher>(m, "PythonAPIDispatcher")
|
||||
.def(py::init(&MakePythonAPIDispatcher))
|
||||
.def("Dispatch", Dispatch);
|
||||
}
|
@ -398,3 +398,5 @@ stream_executor::port::internal_statusor::Helper::Crash
|
||||
[tensor_handle] # tfe
|
||||
tensorflow::TensorHandle::Tensor
|
||||
|
||||
[python_api_dispatcher] # python_api_dispatcher
|
||||
tensorflow::PythonAPIDispatcher
|
||||
|
Loading…
Reference in New Issue
Block a user