Add a helper method for adding TFLite hashtable ops to op resolver.

Also added a python wrapper for TFLite hashtable ops.

PiperOrigin-RevId: 290565157
Change-Id: Ieb1be2c4c4129f1256599a22bbccba6a6fab8f69
This commit is contained in:
Jaesung Chung 2020-01-19 23:55:08 -08:00 committed by TensorFlower Gardener
parent dd116af48d
commit 03a2255949
6 changed files with 110 additions and 9 deletions

View File

@ -1,4 +1,5 @@
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
package(
default_visibility = [
@ -132,8 +133,12 @@ cc_library(
"hashtable.cc",
"hashtable_find.cc",
"hashtable_import.cc",
"hashtable_ops.cc",
"hashtable_size.cc",
],
hdrs = [
"hashtable_ops.h",
],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
@ -168,3 +173,14 @@ cc_test(
"@flatbuffers",
],
)
tf_py_wrap_cc(
name = "hashtable_ops_py_wrapper",
srcs = [
"hashtable_ops.i",
],
deps = [
":hashtable_op_kernels",
"//third_party/python_runtime:headers",
],
)

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
namespace tflite {
namespace ops {
namespace custom {
extern "C" void AddHashtableOps(::tflite::MutableOpResolver* resolver) {
// Add hashtable op handlers.
resolver->AddCustom("HashTableV2", tflite::ops::custom::Register_HASHTABLE());
resolver->AddCustom("LookupTableFindV2",
tflite::ops::custom::Register_HASHTABLE_FIND());
resolver->AddCustom("LookupTableImportV2",
tflite::ops::custom::Register_HASHTABLE_IMPORT());
resolver->AddCustom("LookupTableSizeV2",
tflite::ops::custom::Register_HASHTABLE_SIZE());
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_HASHTABLE_OPS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_HASHTABLE_OPS_H_
#include "tensorflow/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
extern "C" void AddHashtableOps(::tflite::MutableOpResolver* resolver);
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_KERNELS_HASHTABLE_OPS_H_

View File

@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
%}
%include "tensorflow/lite/experimental/kernels/hashtable_ops.h"

View File

@ -222,6 +222,7 @@ cc_library(
"@com_google_absl//absl/strings",
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:framework",
"//tensorflow/lite/experimental/kernels:hashtable_op_kernels",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:custom_ops",

View File

@ -24,6 +24,7 @@ limitations under the License.
#if !defined(__APPLE__)
#include "tensorflow/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/lite/experimental/kernels/hashtable_ops.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
@ -322,15 +323,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
buildinop_resolver_->AddCustom("RFFT2D",
tflite::ops::custom::Register_RFFT2D());
buildinop_resolver_->AddCustom("HashTableV2",
tflite::ops::custom::Register_HASHTABLE());
buildinop_resolver_->AddCustom(
"LookupTableFindV2", tflite::ops::custom::Register_HASHTABLE_FIND());
buildinop_resolver_->AddCustom(
"LookupTableImportV2",
tflite::ops::custom::Register_HASHTABLE_IMPORT());
buildinop_resolver_->AddCustom(
"LookupTableSizeV2", tflite::ops::custom::Register_HASHTABLE_SIZE());
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
}
switch (delegate_type) {