From 5b2ec7d103cb0d93c7271aba4a05c28db792dcc3 Mon Sep 17 00:00:00 2001
From: Kibeom Kim <kkb@google.com>
Date: Mon, 14 Sep 2020 09:17:52 -0700
Subject: [PATCH] Implement fast C++ Python function parameter canonicalizer
 util.

It flattens `args` and `kwargs` according to the specified function spec, for example:

def matmul(
    a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False,
    a_is_sparse=False, b_is_sparse=False, name=None
):

args = (2, 3, True)
kwargs = {'adjoint_a': True, 'name': 'my_matmul'}

Then using `FunctionParameterCanonicalizer.Canonicalize(...)`, users can get the following list of PyObject*.

[2, 3, True, False, True, False, False, False, 'my_matmul']

PiperOrigin-RevId: 331560836
Change-Id: Ib6380f81659909a2fe9422a25219b3f1cae85224
---
 tensorflow/python/BUILD                       |  38 +++++
 .../util/function_parameter_canonicalizer.cc  | 134 ++++++++++++++++++
 .../util/function_parameter_canonicalizer.h   |  73 ++++++++++
 ...arameter_canonicalizer_binding_for_test.cc |  71 ++++++++++
 .../function_parameter_canonicalizer_test.py  |  89 ++++++++++++
 5 files changed, 405 insertions(+)
 create mode 100644 tensorflow/python/util/function_parameter_canonicalizer.cc
 create mode 100644 tensorflow/python/util/function_parameter_canonicalizer.h
 create mode 100644 tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc
 create mode 100644 tensorflow/python/util/function_parameter_canonicalizer_test.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d83d21907ce..395d281dccd 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5532,6 +5532,44 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "function_parameter_canonicalizer",
+    srcs = ["util/function_parameter_canonicalizer.cc"],
+    hdrs = ["util/function_parameter_canonicalizer.h"],
+    deps = [
+        ":py_util",
+        ":safe_pyobject_ptr",
+        "//tensorflow/core/platform:logging",
+        "//tensorflow/core/platform:macros",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/container:flat_hash_set",  # DCHECK-only uniqueness check
+        "@com_google_absl//absl/types:span",  # appears in the public header's signatures
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_function_parameter_canonicalizer_binding_for_test",
+    testonly = True,  # consumed by :function_parameter_canonicalizer_test
+    srcs = ["util/function_parameter_canonicalizer_binding_for_test.cc"],
+    module_name = "_function_parameter_canonicalizer_binding_for_test",
+    deps = [
+        ":function_parameter_canonicalizer",
+        "//third_party/python_runtime:headers",  # buildcleaner: keep
+        "@com_google_absl//absl/types:span",
+        "@pybind11",
+    ],
+)
+
+tf_py_test(
+    name = "function_parameter_canonicalizer_test",
+    srcs = ["util/function_parameter_canonicalizer_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":_function_parameter_canonicalizer_binding_for_test",  # pybind module the test imports
+        ":client_testlib",
+    ],
+)
+
 py_library(
     name = "global_test_configuration",
     deps = if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration"]) +
diff --git a/tensorflow/python/util/function_parameter_canonicalizer.cc b/tensorflow/python/util/function_parameter_canonicalizer.cc
new file mode 100644
index 00000000000..3ae98ee0fc8
--- /dev/null
+++ b/tensorflow/python/util/function_parameter_canonicalizer.cc
@@ -0,0 +1,134 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/util/function_parameter_canonicalizer.h"
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/python/lib/core/py_util.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+
+namespace tensorflow {
+
+FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
+    absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
+    : positional_args_size_(arg_names.size() - defaults.size()) {
+  DCheckPyGilState();
+  DCHECK_GE(positional_args_size_, 0);
+  // Intern each name once so that lookups can compare object pointers.
+  interned_arg_names_.reserve(arg_names.size());
+  for (const char* name : arg_names)
+    interned_arg_names_.emplace_back(PyUnicode_InternFromString(name));
+  DCHECK(AreInternedArgNamesUnique());
+  // Take a new reference to every default before storing it in `defaults_`.
+  for (PyObject* value : defaults) Py_INCREF(value);
+  defaults_.reserve(defaults.size());
+  for (PyObject* value : defaults) defaults_.emplace_back(value);
+}
+
+bool FunctionParameterCanonicalizer::Canonicalize(
+    PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
+  // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
+  DCheckPyGilState();
+  DCHECK(PyTuple_CheckExact(args));
+  // `kwargs` may be null (see below), so only type-check a non-null object.
+  DCHECK(kwargs == nullptr || PyDict_CheckExact(kwargs));
+  DCHECK_EQ(result.size(), interned_arg_names_.size());
+
+  const int args_size = Py_SIZE(args);
+  int remaining_positional_args_count = positional_args_size_ - args_size;
+
+  // Check if too many positional arguments were given.
+  if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
+    // TODO(kkb): Also report the actual numbers.
+    PyErr_SetString(PyExc_TypeError, "Too many arguments were given");
+    return false;
+  }
+
+  // Fill positional arguments.
+  for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
+
+  // Fill default arguments. `defaults_[0]` belongs to parameter index
+  // `positional_args_size_`, so starting any lower would index out of bounds.
+  const int first_default =
+      args_size > positional_args_size_ ? args_size : positional_args_size_;
+  for (int i = first_default; i < interned_arg_names_.size(); ++i)
+    result[i] = defaults_[i - positional_args_size_].get();
+
+  // Fill keyword arguments.
+  if (kwargs != nullptr) {
+    PyObject *key, *value;
+    Py_ssize_t pos = 0;
+    while (PyDict_Next(kwargs, &pos, &key, &value)) {
+      std::size_t index = InternedArgNameLinearSearch(key);
+      // Check if `key` (argument name) was found in the interned name table.
+      if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
+        // `key` might not be an interned string, so intern a copy of the
+        // pointer and try again. `PyUnicode_InternInPlace` may swap out the
+        // reference it is given, so pass one we own, not the borrowed `key`.
+        PyObject* interned_key = key;
+        Py_INCREF(interned_key);
+        PyUnicode_InternInPlace(&interned_key);
+        index = InternedArgNameLinearSearch(interned_key);
+        Py_DECREF(interned_key);
+        // Still not found, then return an error.
+        if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
+          PyErr_Format(PyExc_TypeError,
+                       "Got an unexpected keyword argument '%S'", key);
+          return false;
+        }
+      }
+      // Check if the keyword argument overlaps with positional arguments.
+      if (TF_PREDICT_FALSE(index < args_size)) {
+        PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%S'",
+                     key);
+        return false;
+      }
+      if (TF_PREDICT_FALSE(index < positional_args_size_))
+        --remaining_positional_args_count;
+      result[index] = value;
+    }
+  }
+
+  // Check if all the required arguments are filled.
+  // Example failure, not enough number of arguments passed: `matmul(x)`
+  if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
+    // TODO(kkb): Report what arguments are missing.
+    PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
+    return false;
+  }
+  return true;
+}
+
+ABSL_MUST_USE_RESULT
+ABSL_ATTRIBUTE_HOT
+inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
+    PyObject* name) {
+  // Entries are matched by pointer identity, not by string contents. When the
+  // scan finds nothing, `index` stops at `interned_arg_names_.size()`.
+  std::size_t index = 0;
+  for (; index < interned_arg_names_.size(); ++index)
+    if (TF_PREDICT_FALSE(name == interned_arg_names_[index].get())) break;
+  return index;
+}
+
+bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
+  // Names are stored interned, so a duplicate shows up as a repeated pointer.
+  absl::flat_hash_set<PyObject*> seen;
+  for (const Safe_PyObjectPtr& name : interned_arg_names_)
+    if (!seen.insert(name.get()).second) return false;
+  return true;
+}
+}  // namespace tensorflow
diff --git a/tensorflow/python/util/function_parameter_canonicalizer.h b/tensorflow/python/util/function_parameter_canonicalizer.h
new file mode 100644
index 00000000000..8e7fd7dd693
--- /dev/null
+++ b/tensorflow/python/util/function_parameter_canonicalizer.h
@@ -0,0 +1,73 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
+#define TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
+
+#include <Python.h>
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
+
+namespace tensorflow {
+
+// Canonicalizes Python `args` & `kwargs` into one flat, ordered argument list.
+class FunctionParameterCanonicalizer {
+ public:
+  // `arg_names` is a list of argument names, and `defaults` is default PyObject
+  // instances for arguments. `defaults` is aligned to the end of `arg_names`.
+  FunctionParameterCanonicalizer(absl::Span<const char*> arg_names,
+                                 absl::Span<PyObject*> defaults);
+
+  // Returns the total number of arguments.
+  ABSL_MUST_USE_RESULT
+  int GetArgSize() const { return interned_arg_names_.size(); }
+
+  // Canonicalizes `args` and `kwargs` by the spec specified at construction.
+  // The output is written to `result`, which must hold `GetArgSize()` entries.
+  // Returns `true` if canonicalization was successful, and `false` otherwise.
+  // When it fails, it also sets the CPython error status.
+  // No new references are created for the objects written to `result`:
+  // `PyObject*`s in `result` are borrowed references from `args`, `kwargs`, and
+  // possibly `defaults_`, and will be only valid if `args` and `kwargs` are
+  // still alive.
+  ABSL_MUST_USE_RESULT
+  ABSL_ATTRIBUTE_HOT
+  bool Canonicalize(PyObject* args, PyObject* kwargs,
+                    absl::Span<PyObject*> result);
+
+ private:
+  // Simple linear search of `name` in `interned_arg_names_`. If found, returns
+  // the index. If not found, returns `interned_arg_names_.size()`.
+  ABSL_MUST_USE_RESULT
+  ABSL_ATTRIBUTE_HOT
+  std::size_t InternedArgNameLinearSearch(PyObject* name);
+
+  // Checks whether every entry of `interned_arg_names_` is unique.
+  bool AreInternedArgNamesUnique();
+
+  // TODO(kkb): Use one `std::vector` and two `absl::Span`s instead to improve
+  // cache locality.
+  std::vector<Safe_PyObjectPtr> interned_arg_names_;
+  std::vector<Safe_PyObjectPtr> defaults_;
+  const int positional_args_size_;  // Number of arguments without a default.
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_UTIL_FUNCTION_PARAMETER_CANONICALIZER_H_
diff --git a/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc b/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc
new file mode 100644
index 00000000000..e93f6905734
--- /dev/null
+++ b/tensorflow/python/util/function_parameter_canonicalizer_binding_for_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <Python.h>
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+#include "tensorflow/python/util/function_parameter_canonicalizer.h"
+
+namespace py = pybind11;
+
+class FunctionParameterCanonicalizerWrapper {
+ public:
+  FunctionParameterCanonicalizerWrapper(absl::Span<const char*> arg_names,
+                                        absl::Span<PyObject*> defaults)
+      : function_parameter_canonicalizer_(arg_names, defaults) {}
+  // Public: the `canonicalize` binding below calls into this member directly.
+  tensorflow::FunctionParameterCanonicalizer function_parameter_canonicalizer_;
+};
+
+PYBIND11_MODULE(_function_parameter_canonicalizer_binding_for_test, m) {
+  py::class_<FunctionParameterCanonicalizerWrapper>(
+      m, "FunctionParameterCanonicalizer")
+      .def(py::init([](std::vector<std::string> arg_names, py::tuple defaults) {
+        std::vector<const char*> arg_names_c_str;
+        for (const std::string& name : arg_names)
+          arg_names_c_str.emplace_back(name.c_str());
+
+        tensorflow::Safe_PyObjectPtr defaults_fast(
+            PySequence_Fast(defaults.ptr(), "Expected tuple"));
+        // Test the `PySequence_Fast` result, not the always-valid `defaults`.
+        if (defaults_fast.get() == nullptr) throw py::error_already_set();
+        PyObject** default_items = PySequence_Fast_ITEMS(defaults_fast.get());
+        return new FunctionParameterCanonicalizerWrapper(
+            absl::MakeSpan(arg_names_c_str),
+            absl::MakeSpan(default_items,
+                           PySequence_Fast_GET_SIZE(defaults_fast.get())));
+      }))
+      .def("canonicalize", [](FunctionParameterCanonicalizerWrapper& self,
+                              py::args args, py::kwargs kwargs) {
+        std::vector<PyObject*> result_raw(
+            self.function_parameter_canonicalizer_.GetArgSize());
+        bool succeeded = self.function_parameter_canonicalizer_.Canonicalize(
+            args.ptr(), kwargs.ptr(), absl::MakeSpan(result_raw));
+
+        if (!succeeded) {
+          CHECK(PyErr_Occurred());
+          throw py::error_already_set();
+        }
+        // Copy the canonicalized pointers into a Python list for the caller.
+        py::list result;
+        for (PyObject* obj : result_raw) result.append(obj);
+        return result;
+      });
+}
diff --git a/tensorflow/python/util/function_parameter_canonicalizer_test.py b/tensorflow/python/util/function_parameter_canonicalizer_test.py
new file mode 100644
index 00000000000..968265ff36f
--- /dev/null
+++ b/tensorflow/python/util/function_parameter_canonicalizer_test.py
@@ -0,0 +1,89 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tensorflow::FunctionParameterCanonicalizer`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import _function_parameter_canonicalizer_binding_for_test
+from tensorflow.python.platform import test
+
+
+class FunctionParameterCanonicalizerTest(test.TestCase):
+  """Exercises `canonicalize` against a `matmul`-style parameter spec."""
+  def setUp(self):
+    super(FunctionParameterCanonicalizerTest, self).setUp()
+    self._matmul_func = (
+        _function_parameter_canonicalizer_binding_for_test
+        .FunctionParameterCanonicalizer([
+            'a', 'b', 'transpose_a', 'transpose_b', 'adjoint_a', 'adjoint_b',
+            'a_is_sparse', 'b_is_sparse', 'name'
+        ], (False, False, False, False, False, False, None)))
+
+  def testPosOnly(self):
+    self.assertEqual(
+        self._matmul_func.canonicalize(2, 3),
+        [2, 3, False, False, False, False, False, False, None])
+
+  def testPosOnly2(self):
+    self.assertEqual(
+        self._matmul_func.canonicalize(2, 3, True, False, True),
+        [2, 3, True, False, True, False, False, False, None])
+
+  def testPosAndKwd(self):
+    self.assertEqual(
+        self._matmul_func.canonicalize(
+            2, 3, transpose_a=True, name='my_matmul'),
+        [2, 3, True, False, False, False, False, False, 'my_matmul'])
+
+  def testPosAndKwd2(self):
+    self.assertEqual(
+        self._matmul_func.canonicalize(2, b=3),
+        [2, 3, False, False, False, False, False, False, None])
+
+  def testMissingPos(self):
+    with self.assertRaisesRegex(TypeError,
+                                'Missing required positional argument'):
+      self._matmul_func.canonicalize(2)
+
+  def testMissingPos2(self):
+    with self.assertRaisesRegex(TypeError,
+                                'Missing required positional argument'):
+      self._matmul_func.canonicalize(
+          transpose_a=True, transpose_b=True, adjoint_a=True)
+
+  def testTooManyArgs(self):
+    with self.assertRaisesRegex(TypeError, 'Too many arguments were given'):
+      self._matmul_func.canonicalize(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+
+  def testInvalidKwd(self):
+    with self.assertRaisesRegex(TypeError,
+                                'Got an unexpected keyword argument'):
+      self._matmul_func.canonicalize(2, 3, hohoho=True)
+
+  def testDuplicatedArg(self):
+    with self.assertRaisesRegex(TypeError,
+                                "Got multiple values for argument 'b'"):
+      self._matmul_func.canonicalize(2, 3, False, b=4)
+
+  def testDuplicatedArg2(self):
+    with self.assertRaisesRegex(
+        TypeError, "Got multiple values for argument 'transpose_a'"):
+      self._matmul_func.canonicalize(2, 3, False, transpose_a=True)
+
+
+if __name__ == '__main__':
+  test.main()