Implement fast C++ Python function parameter canonicalizer util.
It flattens `args` and `kwargs` according to the specified function spec, for example: def matmul( a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None ): args = (2, 3, True) kwargs = {'adjoint_a': True, 'name': 'my_matmul'} Then using `FunctionParameterCanonicalizer.Canonicalize(...)`, users can get the following list of PyObject*. [2, 3, True, False, True, False, False, False, 'my_matmul'] PiperOrigin-RevId: 331560836 Change-Id: Ib6380f81659909a2fe9422a25219b3f1cae85224
This commit is contained in:
parent
021440d9d9
commit
5b2ec7d103
@ -5532,6 +5532,44 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_parameter_canonicalizer",
|
||||
srcs = ["util/function_parameter_canonicalizer.cc"],
|
||||
hdrs = ["util/function_parameter_canonicalizer.h"],
|
||||
deps = [
|
||||
":py_util",
|
||||
":safe_pyobject_ptr",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_function_parameter_canonicalizer_binding_for_test",
|
||||
testonly = True,
|
||||
srcs = ["util/function_parameter_canonicalizer_binding_for_test.cc"],
|
||||
module_name = "_function_parameter_canonicalizer_binding_for_test",
|
||||
deps = [
|
||||
":function_parameter_canonicalizer",
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "function_parameter_canonicalizer_test",
|
||||
srcs = ["util/function_parameter_canonicalizer_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":_function_parameter_canonicalizer_binding_for_test",
|
||||
":client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "global_test_configuration",
|
||||
deps = if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration"]) +
|
||||
|
134
tensorflow/python/util/function_parameter_canonicalizer.cc
Normal file
134
tensorflow/python/util/function_parameter_canonicalizer.cc
Normal file
@ -0,0 +1,134 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/util/function_parameter_canonicalizer.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/python/lib/core/py_util.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
|
||||
absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
|
||||
: positional_args_size_(arg_names.size() - defaults.size()) {
|
||||
DCheckPyGilState();
|
||||
DCHECK_GE(positional_args_size_, 0);
|
||||
|
||||
interned_arg_names_.reserve(arg_names.size());
|
||||
for (const char* obj : arg_names)
|
||||
interned_arg_names_.emplace_back(PyUnicode_InternFromString(obj));
|
||||
|
||||
DCHECK(AreInternedArgNamesUnique());
|
||||
|
||||
for (PyObject* obj : defaults) Py_INCREF(obj);
|
||||
defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end());
|
||||
}
|
||||
|
||||
bool FunctionParameterCanonicalizer::Canonicalize(
|
||||
PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
|
||||
// TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
|
||||
|
||||
DCheckPyGilState();
|
||||
DCHECK(PyTuple_CheckExact(args));
|
||||
DCHECK(PyDict_CheckExact(kwargs));
|
||||
DCHECK_EQ(result.size(), interned_arg_names_.size());
|
||||
|
||||
const int args_size = Py_SIZE(args);
|
||||
int remaining_positional_args_count = positional_args_size_ - args_size;
|
||||
|
||||
// Check if the number of input arguments are too many.
|
||||
if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
|
||||
// TODO(kkb): Also report the actual numbers.
|
||||
PyErr_SetString(PyExc_TypeError, "Too many arguments were given");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fill positional arguments.
|
||||
for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
|
||||
|
||||
// Fill default arguments.
|
||||
for (int i = args_size; i < interned_arg_names_.size(); ++i)
|
||||
result[i] = defaults_[i - positional_args_size_].get();
|
||||
|
||||
// Fill keyword arguments.
|
||||
if (kwargs != nullptr) {
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(kwargs, &pos, &key, &value)) {
|
||||
std::size_t index = InternedArgNameLinearSearch(key);
|
||||
|
||||
// Check if key object(argument name) was found in the pre-built intern
|
||||
// string table.
|
||||
if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
|
||||
// `key` might not be an interend string, so get the interned string
|
||||
// and try again.
|
||||
PyUnicode_InternInPlace(&key);
|
||||
index = InternedArgNameLinearSearch(key);
|
||||
|
||||
// Stil not found, then return an error.
|
||||
if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
|
||||
PyErr_Format(PyExc_TypeError,
|
||||
"Got an unexpected keyword argument '%S'", key);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the keyword argument overlaps with positional arguments.
|
||||
if (TF_PREDICT_FALSE(index < args_size)) {
|
||||
PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%S'",
|
||||
key);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (TF_PREDICT_FALSE(index < positional_args_size_))
|
||||
--remaining_positional_args_count;
|
||||
|
||||
result[index] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all the arguments are filled.
|
||||
// Example failure, not enough number of arguments passed: `matmul(x)`
|
||||
if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
|
||||
// TODO(kkb): Report what arguments are missing.
|
||||
PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ABSL_MUST_USE_RESULT
|
||||
ABSL_ATTRIBUTE_HOT
|
||||
inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
|
||||
PyObject* name) {
|
||||
std::size_t result = interned_arg_names_.size();
|
||||
|
||||
for (std::size_t i = 0; i < interned_arg_names_.size(); ++i)
|
||||
if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
|
||||
absl::flat_hash_set<PyObject*> interned_arg_names_set;
|
||||
for (const Safe_PyObjectPtr& obj : interned_arg_names_)
|
||||
interned_arg_names_set.emplace(obj.get());
|
||||
|
||||
return interned_arg_names_set.size() == interned_arg_names_.size();
|
||||
}
|
||||
} // namespace tensorflow
|
73
tensorflow/python/util/function_parameter_canonicalizer.h
Normal file
73
tensorflow/python/util/function_parameter_canonicalizer.h
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
|
||||
#define TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A class that Canonicalizes Python arg & kwargs parameters.
|
||||
class FunctionParameterCanonicalizer {
|
||||
public:
|
||||
// `arg_names` is a list of argument names, and `defaults` is default PyObject
|
||||
// instances for arguments. `default` is aligned to the end.
|
||||
FunctionParameterCanonicalizer(absl::Span<const char*> arg_names,
|
||||
absl::Span<PyObject*> defaults);
|
||||
|
||||
// Returns the total number of arguments.
|
||||
ABSL_MUST_USE_RESULT
|
||||
int GetArgSize() const { return interned_arg_names_.size(); }
|
||||
|
||||
// Canonicalizes `args` and `kwargs` by the spec specified at construction.
|
||||
// It's written to `result`. Returns `true` if Canonicalization was
|
||||
// successful, and `false` otherwise. When it fails, it also sets CPython
|
||||
// error status.
|
||||
// This function does not update reference counter of any Python objects.
|
||||
// `PyObject*`s in `result` are borrowed references from `args`, `kwargs`, and
|
||||
// possibly `defaults_`, and will be only valid if `args` and `kwargs` are
|
||||
// still alive.
|
||||
ABSL_MUST_USE_RESULT
|
||||
ABSL_ATTRIBUTE_HOT
|
||||
bool Canonicalize(PyObject* args, PyObject* kwargs,
|
||||
absl::Span<PyObject*> result);
|
||||
|
||||
private:
|
||||
// Simple linear search of `name` in `interned_arg_names`. If found, returns
|
||||
// the index. If not found, returns `interned_arg_names.size()`.
|
||||
ABSL_MUST_USE_RESULT
|
||||
ABSL_ATTRIBUTE_HOT
|
||||
std::size_t InternedArgNameLinearSearch(PyObject* name);
|
||||
|
||||
// Check if `interned_arg_names_` is unique.
|
||||
bool AreInternedArgNamesUnique();
|
||||
|
||||
// TODO(kkb): Use one `std::vector` and two `absl:Span`s instead to improve
|
||||
// cache locality.
|
||||
std::vector<Safe_PyObjectPtr> interned_arg_names_;
|
||||
std::vector<Safe_PyObjectPtr> defaults_;
|
||||
const int positional_args_size_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
|
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/python/util/function_parameter_canonicalizer.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
class FunctionParameterCanonicalizerWrapper {
|
||||
public:
|
||||
FunctionParameterCanonicalizerWrapper(absl::Span<const char*> arg_names,
|
||||
absl::Span<PyObject*> defaults)
|
||||
: function_parameter_canonicalizer_(arg_names, defaults) {}
|
||||
|
||||
tensorflow::FunctionParameterCanonicalizer function_parameter_canonicalizer_;
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(_function_parameter_canonicalizer_binding_for_test, m) {
|
||||
py::class_<FunctionParameterCanonicalizerWrapper>(
|
||||
m, "FunctionParameterCanonicalizer")
|
||||
.def(py::init([](std::vector<std::string> arg_names, py::tuple defaults) {
|
||||
std::vector<const char*> arg_names_c_str;
|
||||
for (const std::string& name : arg_names)
|
||||
arg_names_c_str.emplace_back(name.c_str());
|
||||
|
||||
tensorflow::Safe_PyObjectPtr defaults_fast(
|
||||
PySequence_Fast(defaults.ptr(), "Expected tuple"));
|
||||
if (!defaults) throw py::error_already_set();
|
||||
PyObject** default_items = PySequence_Fast_ITEMS(defaults_fast.get());
|
||||
return new FunctionParameterCanonicalizerWrapper(
|
||||
absl::MakeSpan(arg_names_c_str),
|
||||
absl::MakeSpan(default_items,
|
||||
PySequence_Fast_GET_SIZE(defaults_fast.get())));
|
||||
}))
|
||||
.def("canonicalize", [](FunctionParameterCanonicalizerWrapper& self,
|
||||
py::args args, py::kwargs kwargs) {
|
||||
std::vector<PyObject*> result_raw(
|
||||
self.function_parameter_canonicalizer_.GetArgSize());
|
||||
|
||||
bool is_suceeded = self.function_parameter_canonicalizer_.Canonicalize(
|
||||
args.ptr(), kwargs.ptr(), absl::MakeSpan(result_raw));
|
||||
|
||||
if (!is_suceeded) {
|
||||
CHECK(PyErr_Occurred());
|
||||
throw py::error_already_set();
|
||||
}
|
||||
|
||||
py::list result;
|
||||
for (PyObject* obj : result_raw) result.append(obj);
|
||||
return result;
|
||||
});
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tensorflow::FunctionParameterCanonicalizer`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import _function_parameter_canonicalizer_binding_for_test
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionParameterCanonicalizerTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(FunctionParameterCanonicalizerTest, self).setUp()
|
||||
self._matmul_func = (
|
||||
_function_parameter_canonicalizer_binding_for_test
|
||||
.FunctionParameterCanonicalizer([
|
||||
'a', 'b', 'transpose_a', 'transpose_b', 'adjoint_a', 'adjoint_b',
|
||||
'a_is_sparse', 'b_is_sparse', 'name'
|
||||
], (False, False, False, False, False, False, None)))
|
||||
|
||||
def testPosOnly(self):
|
||||
self.assertEqual(
|
||||
self._matmul_func.canonicalize(2, 3),
|
||||
[2, 3, False, False, False, False, False, False, None])
|
||||
|
||||
def testPosOnly2(self):
|
||||
self.assertEqual(
|
||||
self._matmul_func.canonicalize(2, 3, True, False, True),
|
||||
[2, 3, True, False, True, False, False, False, None])
|
||||
|
||||
def testPosAndKwd(self):
|
||||
self.assertEqual(
|
||||
self._matmul_func.canonicalize(
|
||||
2, 3, transpose_a=True, name='my_matmul'),
|
||||
[2, 3, True, False, False, False, False, False, 'my_matmul'])
|
||||
|
||||
def testPosAndKwd2(self):
|
||||
self.assertEqual(
|
||||
self._matmul_func.canonicalize(2, b=3),
|
||||
[2, 3, False, False, False, False, False, False, None])
|
||||
|
||||
def testMissingPos(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Missing required positional argument'):
|
||||
self._matmul_func.canonicalize(2)
|
||||
|
||||
def testMissingPos2(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Missing required positional argument'):
|
||||
self._matmul_func.canonicalize(
|
||||
transpose_a=True, transpose_b=True, adjoint_a=True)
|
||||
|
||||
def testTooManyArgs(self):
|
||||
with self.assertRaisesRegex(TypeError, 'Too many arguments were given'):
|
||||
self._matmul_func.canonicalize(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
|
||||
|
||||
def testInvalidKwd(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'Got an unexpected keyword argument'):
|
||||
self._matmul_func.canonicalize(2, 3, hohoho=True)
|
||||
|
||||
def testDuplicatedArg(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"Got multiple values for argument 'b'"):
|
||||
self._matmul_func.canonicalize(2, 3, False, b=4)
|
||||
|
||||
def testDuplicatedArg2(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "Got multiple values for argument 'transpose_a'"):
|
||||
self._matmul_func.canonicalize(2, 3, False, transpose_a=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user