Add pybind for hash table kernels
This change fixes #45195 PiperOrigin-RevId: 345549641 Change-Id: I7305d171061df50bfb0e27275f28b9647478dd05
This commit is contained in:
parent
5bd220564e
commit
077fe29d9d
tensorflow/lite
@ -1,3 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
@ -49,3 +51,21 @@ cc_test(
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "pywrap_hashtable_ops",
|
||||
srcs = [
|
||||
"hashtable_ops_wrapper.cc",
|
||||
],
|
||||
hdrs = ["hashtable_ops.h"],
|
||||
additional_exported_symbols = ["HashtableOpsRegisterer"],
|
||||
link_in_framework = True,
|
||||
module_name = "pywrap_hashtable_ops",
|
||||
deps = [
|
||||
":hashtable_op_kernels",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
34
tensorflow/lite/kernels/hashtable/hashtable_ops_wrapper.cc
Normal file
34
tensorflow/lite/kernels/hashtable/hashtable_ops_wrapper.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
||||
|
||||
PYBIND11_MODULE(pywrap_hashtable_ops, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
pywrap_hashtable_ops
|
||||
-----
|
||||
)pbdoc";
|
||||
m.def(
|
||||
"HashtableOpsRegisterer",
|
||||
[](uintptr_t resolver) {
|
||||
tflite::ops::custom::AddHashtableOps(
|
||||
reinterpret_cast<tflite::MutableOpResolver*>(resolver));
|
||||
},
|
||||
R"pbdoc(
|
||||
Hashtable op registerer function with the correct signature. Registers
|
||||
hashtable custom ops.
|
||||
)pbdoc");
|
||||
}
|
@ -216,7 +216,7 @@ py_test(
|
||||
":lite",
|
||||
":lite_v2_test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||
"//tensorflow/lite/kernels/hashtable:pywrap_hashtable_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
|
@ -27,6 +27,7 @@ from six.moves import range
|
||||
from six.moves import zip
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_v2_test_util
|
||||
from tensorflow.lite.python.convert import mlir_quantize
|
||||
@ -799,7 +800,8 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = InterpreterWithCustomOps(
|
||||
model_content=tflite_model, custom_op_registerers=['AddHashtableOps'])
|
||||
model_content=tflite_model,
|
||||
custom_op_registerers=[hashtable_ops_registerer.HashtableOpsRegisterer])
|
||||
input_details = interpreter.get_input_details()
|
||||
output_details = interpreter.get_output_details()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user