Add pybind for hash table kernels

This change fixes 

PiperOrigin-RevId: 345549641
Change-Id: I7305d171061df50bfb0e27275f28b9647478dd05
This commit is contained in:
Jaesung Chung 2020-12-03 15:35:02 -08:00 committed by TensorFlower Gardener
parent 5bd220564e
commit 077fe29d9d
4 changed files with 58 additions and 2 deletions
tensorflow/lite

View File

@ -1,3 +1,5 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//visibility:public",
@ -49,3 +51,21 @@ cc_test(
"@flatbuffers",
],
)
pybind_extension(
name = "pywrap_hashtable_ops",
srcs = [
"hashtable_ops_wrapper.cc",
],
hdrs = ["hashtable_ops.h"],
additional_exported_symbols = ["HashtableOpsRegisterer"],
link_in_framework = True,
module_name = "pywrap_hashtable_ops",
deps = [
":hashtable_op_kernels",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/python_runtime:headers",
"@pybind11",
],
)

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
PYBIND11_MODULE(pywrap_hashtable_ops, m) {
m.doc() = R"pbdoc(
pywrap_hashtable_ops
-----
)pbdoc";
m.def(
"HashtableOpsRegisterer",
[](uintptr_t resolver) {
tflite::ops::custom::AddHashtableOps(
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
},
R"pbdoc(
Hashtable op registerer function with the correct signature. Registers
hashtable custom ops.
)pbdoc");
}

View File

@ -216,7 +216,7 @@ py_test(
":lite",
":lite_v2_test_util",
"//tensorflow:tensorflow_py",
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
"//tensorflow/lite/kernels/hashtable:pywrap_hashtable_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",

View File

@ -27,6 +27,7 @@ from six.moves import range
from six.moves import zip
import tensorflow as tf
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python.convert import mlir_quantize
@ -799,7 +800,8 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
# Check values from converted model.
interpreter = InterpreterWithCustomOps(
model_content=tflite_model, custom_op_registerers=['AddHashtableOps'])
model_content=tflite_model,
custom_op_registerers=[hashtable_ops_registerer.HashtableOpsRegisterer])
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()