Make Hashtable ops in TFLite compatible with MLIR converter
PiperOrigin-RevId: 288628649 Change-Id: Ie7dbd06e1e277e33da5d08993d27d08920837f3a
This commit is contained in:
parent
7c965ff357
commit
57b4fbbfdb
tensorflow/lite
experimental
kernels
resource
kernels
testing
@ -130,13 +130,14 @@ cc_library(
|
||||
name = "hashtable_op_kernels",
|
||||
srcs = [
|
||||
"hashtable.cc",
|
||||
"hashtable_find.cc",
|
||||
"hashtable_import.cc",
|
||||
"hashtable_lookup.cc",
|
||||
"hashtable_size.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/experimental/resource",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
@ -26,7 +28,10 @@ namespace ops {
|
||||
namespace custom {
|
||||
namespace hashtable {
|
||||
|
||||
constexpr int kResourceHandleTensor = 0;
|
||||
static constexpr int kResourceHandleTensor = 0;
|
||||
static constexpr const char kSharedNameStr[] = "shared_name";
|
||||
static constexpr const char kKeyDtypeStr[] = "key_dtype";
|
||||
static constexpr const char kValueDtypeStr[] = "value_dtype";
|
||||
|
||||
// TODO(b/144728911): The following structure should be moved to
|
||||
// builtin_op_data.h when it is ready to become a builtin op.
|
||||
@ -41,11 +46,18 @@ void* InitHashtable(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
|
||||
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
||||
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
|
||||
const std::string table_name = m[kSharedNameStr].AsString().str();
|
||||
|
||||
TfLiteType key_dtype, value_dtype;
|
||||
ConvertTensorType(static_cast<TensorType>(m[kKeyDtypeStr].AsInt32()),
|
||||
&key_dtype, nullptr);
|
||||
ConvertTensorType(static_cast<TensorType>(m[kValueDtypeStr].AsInt32()),
|
||||
&value_dtype, nullptr);
|
||||
|
||||
TfLiteHashtableParams* option = new TfLiteHashtableParams;
|
||||
option->table_name = m["table_name"].AsString().str();
|
||||
option->key_dtype = static_cast<TfLiteType>(m["key_dtype"].AsInt32());
|
||||
option->value_dtype = static_cast<TfLiteType>(m["value_dtype"].AsInt32());
|
||||
option->table_name = table_name;
|
||||
option->key_dtype = key_dtype;
|
||||
option->value_dtype = value_dtype;
|
||||
|
||||
return option;
|
||||
}
|
||||
@ -61,12 +73,12 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
const auto* params =
|
||||
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE(context, !params->table_name.empty());
|
||||
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt32 ||
|
||||
params->key_dtype == kTfLiteString));
|
||||
TF_LITE_ENSURE(context, (params->value_dtype == kTfLiteInt32 ||
|
||||
params->value_dtype == kTfLiteString ||
|
||||
params->value_dtype == kTfLiteFloat32));
|
||||
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt64 &&
|
||||
params->value_dtype == kTfLiteString) ||
|
||||
(params->key_dtype == kTfLiteString &&
|
||||
params->value_dtype == kTfLiteInt64));
|
||||
|
||||
TfLiteTensor* resource_handle_tensor =
|
||||
GetOutput(context, node, kResourceHandleTensor);
|
||||
@ -78,6 +90,7 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, node->user_data != nullptr);
|
||||
const auto* params =
|
||||
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
|
||||
|
||||
@ -100,12 +113,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace hashtable
|
||||
|
||||
TfLiteRegistration* Register_HASHTABLE() {
|
||||
static TfLiteRegistration r = {hashtable::InitHashtable,
|
||||
hashtable::FreeHashtable,
|
||||
hashtable::PrepareHashtable,
|
||||
hashtable::EvalHashtable,
|
||||
nullptr,
|
||||
BuiltinOperator_CUSTOM};
|
||||
static TfLiteRegistration r = {
|
||||
hashtable::InitHashtable, hashtable::FreeHashtable,
|
||||
hashtable::PrepareHashtable, hashtable::EvalHashtable};
|
||||
return &r;
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ constexpr int kKeyTensor = 1;
|
||||
constexpr int kDefaultValueTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
@ -42,26 +42,19 @@ TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
const TfLiteTensor* default_value_tensor =
|
||||
GetInput(context, node, kDefaultValueTensor);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(default_value_tensor), 1);
|
||||
TF_LITE_ENSURE_EQ(context, SizeOfDimension(default_value_tensor, 0), 1);
|
||||
|
||||
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
|
||||
TF_LITE_ENSURE(context, (output_tensor->type == kTfLiteInt32 ||
|
||||
output_tensor->type == kTfLiteString ||
|
||||
output_tensor->type == kTfLiteFloat32));
|
||||
|
||||
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
|
||||
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
|
||||
key_tensor->type == kTfLiteString));
|
||||
if (output_tensor->type != kTfLiteString) {
|
||||
return context->ResizeTensor(context, output_tensor,
|
||||
TfLiteIntArrayCopy(key_tensor->dims));
|
||||
}
|
||||
return kTfLiteOk;
|
||||
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
|
||||
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
|
||||
output_tensor->type == kTfLiteString) ||
|
||||
(key_tensor->type == kTfLiteString &&
|
||||
output_tensor->type == kTfLiteInt64));
|
||||
return context->ResizeTensor(context, output_tensor,
|
||||
TfLiteIntArrayCopy(key_tensor->dims));
|
||||
}
|
||||
|
||||
TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input_resource_id_tensor =
|
||||
GetInput(context, node, kInputResourceIdTensor);
|
||||
int resource_id = input_resource_id_tensor->data.i32[0];
|
||||
@ -77,19 +70,18 @@ TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(context, lookup != nullptr);
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor));
|
||||
return lookup->Lookup(context, key_tensor, output_tensor,
|
||||
default_value_tensor);
|
||||
auto result =
|
||||
lookup->Lookup(context, key_tensor, output_tensor, default_value_tensor);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace hashtable
|
||||
|
||||
TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
|
||||
TfLiteRegistration* Register_HASHTABLE_FIND() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
hashtable::PrepareHashtableLookup,
|
||||
hashtable::EvalHashtableLookup,
|
||||
nullptr,
|
||||
BuiltinOperator_CUSTOM};
|
||||
hashtable::PrepareHashtableFind,
|
||||
hashtable::EvalHashtableFind};
|
||||
return &r;
|
||||
}
|
||||
|
@ -40,13 +40,11 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
|
||||
|
||||
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
|
||||
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
|
||||
key_tensor->type == kTfLiteString));
|
||||
|
||||
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
|
||||
TF_LITE_ENSURE(context, (value_tensor->type == kTfLiteInt32 ||
|
||||
value_tensor->type == kTfLiteString ||
|
||||
value_tensor->type == kTfLiteFloat32));
|
||||
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
|
||||
value_tensor->type == kTfLiteString) ||
|
||||
(key_tensor->type == kTfLiteString &&
|
||||
value_tensor->type == kTfLiteInt64));
|
||||
// TODO(b/144731295): Tensorflow lookup ops support 1-D vector in storing
|
||||
// values.
|
||||
TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor));
|
||||
@ -69,7 +67,8 @@ TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
|
||||
lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor));
|
||||
// The hashtable resource will only be initialized once, attempting to
|
||||
// initialize it multiple times will be a no-op.
|
||||
return lookup->Import(context, key_tensor, value_tensor);
|
||||
auto result = lookup->Import(context, key_tensor, value_tensor);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace hashtable
|
||||
@ -78,9 +77,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
hashtable::PrepareHashtableImport,
|
||||
hashtable::EvalHashtableImport,
|
||||
nullptr,
|
||||
BuiltinOperator_CUSTOM};
|
||||
hashtable::EvalHashtableImport};
|
||||
return &r;
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,7 @@ namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_HASHTABLE();
|
||||
TfLiteRegistration* Register_HASHTABLE_LOOKUP();
|
||||
TfLiteRegistration* Register_HASHTABLE_FIND();
|
||||
TfLiteRegistration* Register_HASHTABLE_IMPORT();
|
||||
TfLiteRegistration* Register_HASHTABLE_SIZE();
|
||||
|
||||
@ -45,6 +45,10 @@ namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
static constexpr const char kSharedNameStr[] = "shared_name";
|
||||
static constexpr const char kKeyDtypeStr[] = "key_dtype";
|
||||
static constexpr const char kValueDtypeStr[] = "value_dtype";
|
||||
|
||||
typedef enum {
|
||||
kResourceTensorId = 0,
|
||||
kKeyTensorId = 1,
|
||||
@ -84,6 +88,19 @@ void SetTensorData(Interpreter* interpreter, int tensorId,
|
||||
buf.WriteToTensorAsVector(tensor);
|
||||
}
|
||||
|
||||
TensorType ConvertTfLiteType(TfLiteType type) {
|
||||
// Currently, hashtable kernels support INT64 and STRING types only.
|
||||
switch (type) {
|
||||
case kTfLiteInt64:
|
||||
return TensorType_INT64;
|
||||
case kTfLiteString:
|
||||
return TensorType_STRING;
|
||||
default:
|
||||
CHECK(false); // Not reached.
|
||||
return TensorType_MIN;
|
||||
}
|
||||
}
|
||||
|
||||
// HashtableGraph generates a graph with hash table ops. This class can create
|
||||
// the following scenarios:
|
||||
//
|
||||
@ -120,7 +137,7 @@ class HashtableGraph {
|
||||
// Hash table lookup node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
|
||||
&node_index);
|
||||
|
||||
// Hash table size node.
|
||||
@ -142,7 +159,7 @@ class HashtableGraph {
|
||||
// Hash table lookup node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
|
||||
&node_index);
|
||||
|
||||
// Hash table size node.
|
||||
@ -174,7 +191,7 @@ class HashtableGraph {
|
||||
// Hash table lookup node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
|
||||
&node_index);
|
||||
|
||||
// Hash table size node.
|
||||
@ -201,7 +218,7 @@ class HashtableGraph {
|
||||
// Hash table lookup node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
|
||||
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
|
||||
&node_index);
|
||||
|
||||
// Hash table size node.
|
||||
@ -226,8 +243,8 @@ class HashtableGraph {
|
||||
// Hash table two lookup node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
{kResourceTwoTensorId, kQueryTwoTensorId, kDefaultValueTwoTensorId},
|
||||
{kResultTwoTensorId}, nullptr, 0, nullptr,
|
||||
hashtable_lookup_registration_, &node_index);
|
||||
{kResultTwoTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
|
||||
&node_index);
|
||||
|
||||
// Hash table two size node.
|
||||
interpreter_->AddNodeWithParameters(
|
||||
@ -261,16 +278,16 @@ class HashtableGraph {
|
||||
default_value_two_ = default_value;
|
||||
}
|
||||
|
||||
int GetTableSize() {
|
||||
int64_t GetTableSize() {
|
||||
auto* size_tensor = interpreter_->tensor(kSizeTensorId);
|
||||
auto size_tensor_shape = GetTensorShape(size_tensor);
|
||||
return GetTensorData<int>(size_tensor)[0];
|
||||
return GetTensorData<int64_t>(size_tensor)[0];
|
||||
}
|
||||
|
||||
int GetTableTwoSize() {
|
||||
int64_t GetTableTwoSize() {
|
||||
auto* size_tensor = interpreter_->tensor(kSizeTwoTensorId);
|
||||
auto size_tensor_shape = GetTensorShape(size_tensor);
|
||||
return GetTensorData<int>(size_tensor)[0];
|
||||
return GetTensorData<int64_t>(size_tensor)[0];
|
||||
}
|
||||
|
||||
std::vector<ValueType> GetLookupResult() {
|
||||
@ -363,7 +380,7 @@ class HashtableGraph {
|
||||
TfLiteQuantization());
|
||||
|
||||
// Result tensor for size calculation.
|
||||
interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt32, "",
|
||||
interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt64, "",
|
||||
{1}, TfLiteQuantization());
|
||||
|
||||
// Default value tensor for lookup.
|
||||
@ -396,7 +413,7 @@ class HashtableGraph {
|
||||
{static_cast<int>(queries_two_.size())}, TfLiteQuantization());
|
||||
|
||||
// Result tensor for size calculation.
|
||||
interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt32,
|
||||
interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt64,
|
||||
"", {1}, TfLiteQuantization());
|
||||
|
||||
// Default value tensor for lookup.
|
||||
@ -433,9 +450,9 @@ class HashtableGraph {
|
||||
hashtable_registration_ = tflite::ops::custom::Register_HASHTABLE();
|
||||
ASSERT_NE(hashtable_registration_, nullptr);
|
||||
|
||||
hashtable_lookup_registration_ =
|
||||
tflite::ops::custom::Register_HASHTABLE_LOOKUP();
|
||||
ASSERT_NE(hashtable_lookup_registration_, nullptr);
|
||||
hashtable_find_registration_ =
|
||||
tflite::ops::custom::Register_HASHTABLE_FIND();
|
||||
ASSERT_NE(hashtable_find_registration_, nullptr);
|
||||
|
||||
hashtable_import_registration_ =
|
||||
tflite::ops::custom::Register_HASHTABLE_IMPORT();
|
||||
@ -447,11 +464,15 @@ class HashtableGraph {
|
||||
}
|
||||
|
||||
std::vector<uint8_t> GetHashtableParamsInFlatbuffer() {
|
||||
TensorType key_tensor_type = ConvertTfLiteType(key_type_);
|
||||
TensorType value_tensor_type = ConvertTfLiteType(value_type_);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.String("table_name", "test_table_name" + std::to_string(std::rand()));
|
||||
fbb.Int("key_dtype", key_type_);
|
||||
fbb.Int("value_dtype", value_type_);
|
||||
fbb.String(kSharedNameStr,
|
||||
"test_table_name" + std::to_string(std::rand()));
|
||||
fbb.Int(kKeyDtypeStr, key_tensor_type);
|
||||
fbb.Int(kValueDtypeStr, value_tensor_type);
|
||||
});
|
||||
fbb.Finish();
|
||||
return fbb.GetBuffer();
|
||||
@ -475,7 +496,7 @@ class HashtableGraph {
|
||||
|
||||
// Op registrations.
|
||||
TfLiteRegistration* hashtable_registration_;
|
||||
TfLiteRegistration* hashtable_lookup_registration_;
|
||||
TfLiteRegistration* hashtable_find_registration_;
|
||||
TfLiteRegistration* hashtable_import_registration_;
|
||||
TfLiteRegistration* hashtable_size_registration_;
|
||||
|
||||
@ -539,64 +560,27 @@ class HashtableDefaultGraphTest {
|
||||
std::vector<ValueType> lookup_result_;
|
||||
};
|
||||
|
||||
TEST(HashtableOpsTest, TestInt32ToInt32Hashtable) {
|
||||
HashtableDefaultGraphTest<int, int> t(
|
||||
kTfLiteInt32, kTfLiteInt32,
|
||||
/*keys=*/{1, 2, 3}, /*values=*/{4, 5, 6}, /*queries=*/{2, 3, 4},
|
||||
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
|
||||
t.InvokeAndVerifyIntResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestInt32ToFloat32Hashtable) {
|
||||
HashtableDefaultGraphTest<int, float> t(
|
||||
kTfLiteInt32, kTfLiteFloat32,
|
||||
/*keys=*/{1, 2, 3}, /*values=*/{4.0f, 5.0f, 6.0f}, /*queries=*/{2, 3, 4},
|
||||
/*default_value=*/-1.0f, /*table_size=*/3,
|
||||
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
|
||||
t.InvokeAndVerifyFloatResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestInt32ToStringHashtable) {
|
||||
HashtableDefaultGraphTest<int, std::string> t(
|
||||
kTfLiteInt32, kTfLiteString,
|
||||
TEST(HashtableOpsTest, TestInt64ToStringHashtable) {
|
||||
HashtableDefaultGraphTest<std::int64_t, std::string> t(
|
||||
kTfLiteInt64, kTfLiteString,
|
||||
/*keys=*/{1, 2, 3}, /*values=*/{"a", "b", "c"}, /*queries=*/{2, 3, 4},
|
||||
/*default_value=*/"d", /*table_size=*/3,
|
||||
/*lookup_result=*/{"b", "c", "d"});
|
||||
t.InvokeAndVerifyStringResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestStringToInt32Hashtable) {
|
||||
HashtableDefaultGraphTest<std::string, int> t(
|
||||
kTfLiteString, kTfLiteInt32,
|
||||
TEST(HashtableOpsTest, TestStringToInt64Hashtable) {
|
||||
HashtableDefaultGraphTest<std::string, int64_t> t(
|
||||
kTfLiteString, kTfLiteInt64,
|
||||
/*keys=*/{"A", "B", "C"}, /*values=*/{4, 5, 6},
|
||||
/*queries=*/{"B", "C", "D"},
|
||||
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
|
||||
t.InvokeAndVerifyIntResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestStringToFloat32Hashtable) {
|
||||
HashtableDefaultGraphTest<std::string, float> t(
|
||||
kTfLiteString, kTfLiteFloat32,
|
||||
/*keys=*/{"A", "B", "C"}, /*values=*/{4.0f, 5.0f, 6.0f},
|
||||
/*queries=*/{"B", "C", "D"},
|
||||
/*default_value=*/-1.0f, /*table_size=*/3,
|
||||
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
|
||||
t.InvokeAndVerifyFloatResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestStringToStringHashtable) {
|
||||
HashtableDefaultGraphTest<std::string, std::string> t(
|
||||
kTfLiteString, kTfLiteString,
|
||||
/*keys=*/{"A", "B", "C"}, /*values=*/{"a", "b", "c"},
|
||||
/*queries=*/{"B", "C", "D"},
|
||||
/*default_value=*/"d", /*table_size=*/3,
|
||||
/*lookup_result=*/{"b", "c", "d"});
|
||||
t.InvokeAndVerifyStringResult();
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestNoImport) {
|
||||
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
|
||||
graph.SetQuery({1, 2, 3}, -1);
|
||||
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
|
||||
graph.SetQuery({"1", "2", "3"}, -1);
|
||||
graph.AddTensors();
|
||||
graph.BuildNoImportGraph();
|
||||
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
|
||||
@ -607,9 +591,9 @@ TEST(HashtableOpsTest, TestNoImport) {
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestImportTwice) {
|
||||
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
|
||||
graph.SetTable({1, 2, 3}, {4, 5, 6});
|
||||
graph.SetQuery({2, 3, 4}, -1);
|
||||
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
|
||||
graph.SetTable({"1", "2", "3"}, {4, 5, 6});
|
||||
graph.SetQuery({"2", "3", "4"}, -1);
|
||||
graph.AddTensors();
|
||||
graph.BuildImportTwiceGraph();
|
||||
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
|
||||
@ -621,11 +605,11 @@ TEST(HashtableOpsTest, TestImportTwice) {
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestTwoHashtables) {
|
||||
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
|
||||
graph.SetTable({1, 2, 3}, {4, 5, 6});
|
||||
graph.SetQuery({2, 3, 4}, -1);
|
||||
graph.SetTableTwo({-1, -2, -3}, {7, 8, 9});
|
||||
graph.SetQueryForTableTwo({-4, -2, -3}, -2);
|
||||
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
|
||||
graph.SetTable({"1", "2", "3"}, {4, 5, 6});
|
||||
graph.SetQuery({"2", "3", "4"}, -1);
|
||||
graph.SetTableTwo({"-1", "-2", "-3"}, {7, 8, 9});
|
||||
graph.SetQueryForTableTwo({"-4", "-2", "-3"}, -2);
|
||||
graph.AddTensors(/*table_two_initialization=*/true);
|
||||
graph.BuildTwoHashtablesGraph();
|
||||
EXPECT_EQ(graph.AllocateTensors(/*table_two_initialization=*/true),
|
||||
@ -639,9 +623,9 @@ TEST(HashtableOpsTest, TestTwoHashtables) {
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
|
||||
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
|
||||
graph.SetTable({1, 2, 3}, {4, 5});
|
||||
graph.SetQuery({2, 3, 4}, -1);
|
||||
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
|
||||
graph.SetTable({"1", "2", "3"}, {4, 5});
|
||||
graph.SetQuery({"2", "3", "4"}, -1);
|
||||
graph.AddTensors();
|
||||
graph.BuildDefaultGraph();
|
||||
EXPECT_EQ(graph.AllocateTensors(), kTfLiteError);
|
||||
@ -650,16 +634,16 @@ TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
|
||||
// HashtableOpModel creates a model with one signle Hashtable op.
|
||||
class HashtableOpModel : public SingleOpModel {
|
||||
public:
|
||||
explicit HashtableOpModel(const char* table_name, TfLiteType key_dtype,
|
||||
TfLiteType value_dtype) {
|
||||
explicit HashtableOpModel(const char* table_name, TensorType key_dtype,
|
||||
TensorType value_dtype) {
|
||||
output_ = AddOutput(GetTensorType<int>());
|
||||
|
||||
// Set up and pass in custom options using flexbuffer.
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.String("table_name", std::string(table_name));
|
||||
fbb.Int("key_dtype", key_dtype);
|
||||
fbb.Int("value_dtype", value_dtype);
|
||||
fbb.String(kSharedNameStr, std::string(table_name));
|
||||
fbb.Int(kKeyDtypeStr, key_dtype);
|
||||
fbb.Int(kValueDtypeStr, value_dtype);
|
||||
});
|
||||
fbb.Finish();
|
||||
SetCustomOp("HASHTABLE", fbb.GetBuffer(),
|
||||
@ -679,7 +663,7 @@ class HashtableOpModel : public SingleOpModel {
|
||||
};
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtable) {
|
||||
HashtableOpModel m("test_hashtable", kTfLiteInt32, kTfLiteString);
|
||||
HashtableOpModel m("test_hashtable", TensorType_INT64, TensorType_STRING);
|
||||
EXPECT_EQ(m.GetResources().size(), 0);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||
@ -689,12 +673,12 @@ TEST(HashtableOpsTest, TestHashtable) {
|
||||
EXPECT_NE(resource_id, 0);
|
||||
auto* hashtable = resource::GetHashtableResource(&resources, resource_id);
|
||||
EXPECT_TRUE(hashtable != nullptr);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
|
||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
|
||||
TfLiteTensor CreateTensor(TfLiteType type, const std::vector<T>& vec) {
|
||||
TfLiteTensor tensor = {};
|
||||
TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
|
||||
dims->data[0] = vec.size();
|
||||
@ -715,6 +699,28 @@ TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <>
|
||||
TfLiteTensor CreateTensor(TfLiteType type,
|
||||
const std::vector<std::string>& vec) {
|
||||
TfLiteTensor tensor = {};
|
||||
TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
|
||||
dims->data[0] = vec.size();
|
||||
tensor.dims = dims;
|
||||
tensor.name = "";
|
||||
tensor.params = {};
|
||||
tensor.quantization = {kTfLiteNoQuantization, nullptr};
|
||||
tensor.is_variable = false;
|
||||
tensor.allocation_type = kTfLiteDynamic;
|
||||
tensor.allocation = nullptr;
|
||||
tensor.type = type;
|
||||
DynamicBuffer buf;
|
||||
for (std::string str : vec) {
|
||||
buf.AddString(str.c_str(), str.size());
|
||||
}
|
||||
buf.WriteToTensor(&tensor, nullptr);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename KeyType, typename ValueType>
|
||||
void InitHashtableResource(resource::ResourceMap* resources, int resource_id,
|
||||
TfLiteType key_type, TfLiteType value_type,
|
||||
@ -772,12 +778,12 @@ class BaseHashtableOpModel : public SingleOpModel {
|
||||
TensorType value_type_;
|
||||
};
|
||||
|
||||
// HashtableLookupOpModel creates a model with a HashtableLookup op.
|
||||
// HashtableFindOpModel creates a model with a HashtableLookup op.
|
||||
template <typename KeyType, typename ValueType>
|
||||
class HashtableLookupOpModel : public BaseHashtableOpModel {
|
||||
class HashtableFindOpModel : public BaseHashtableOpModel {
|
||||
public:
|
||||
HashtableLookupOpModel(const TensorType key_type, const TensorType value_type,
|
||||
int lookup_size) {
|
||||
HashtableFindOpModel(const TensorType key_type, const TensorType value_type,
|
||||
int lookup_size) {
|
||||
key_type_ = key_type;
|
||||
value_type_ = value_type;
|
||||
|
||||
@ -787,8 +793,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
|
||||
|
||||
output_ = AddOutput({value_type, {lookup_size}});
|
||||
|
||||
SetCustomOp("HASHTABLE_LOOKUP", {},
|
||||
tflite::ops::custom::Register_HASHTABLE_LOOKUP);
|
||||
SetCustomOp("HASHTABLE_FIND", {},
|
||||
tflite::ops::custom::Register_HASHTABLE_FIND);
|
||||
BuildInterpreter(
|
||||
{GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)});
|
||||
}
|
||||
@ -797,46 +803,56 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
|
||||
PopulateTensor(lookup_, data);
|
||||
}
|
||||
|
||||
void SetStringLookup(const std::vector<std::string>& data) {
|
||||
PopulateStringTensor(lookup_, data);
|
||||
}
|
||||
|
||||
void SetDefaultValue(const std::vector<ValueType>& data) {
|
||||
PopulateTensor(default_value_, data);
|
||||
}
|
||||
|
||||
void SetStringDefaultValue(const std::vector<std::string>& data) {
|
||||
PopulateStringTensor(default_value_, data);
|
||||
}
|
||||
|
||||
private:
|
||||
int lookup_;
|
||||
int default_value_;
|
||||
};
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
|
||||
TEST(HashtableOpsTest, TestHashtableLookupStringToInt64) {
|
||||
const int kResourceId = 42;
|
||||
HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||
TensorType_INT32, 3);
|
||||
HashtableFindOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||
TensorType_INT64, 3);
|
||||
|
||||
m.SetResourceId({kResourceId});
|
||||
m.SetLookup({5, 6, 7});
|
||||
m.SetStringLookup({"5", "6", "7"});
|
||||
m.SetDefaultValue({4});
|
||||
|
||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
||||
InitHashtableResource<std::string, std::int64_t>(
|
||||
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
|
||||
{"4", "5", "6"}, {1, 2, 3});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4}));
|
||||
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({2, 3, 4}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
|
||||
TEST(HashtableOpsTest, TestHashtableLookupInt64ToString) {
|
||||
const int kResourceId = 42;
|
||||
HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||
TensorType_FLOAT32, 3);
|
||||
HashtableFindOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||
TensorType_STRING, 3);
|
||||
|
||||
m.SetResourceId({kResourceId});
|
||||
m.SetLookup({5, 6, 7});
|
||||
m.SetDefaultValue({4.0f});
|
||||
m.SetStringDefaultValue({"4"});
|
||||
|
||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
||||
kTfLiteFloat32, {4, 5, 6}, {1.0f, 2.0f, 3.0f});
|
||||
InitHashtableResource<std::int64_t, std::string>(
|
||||
&m.GetResources(), kResourceId, kTfLiteInt64, kTfLiteString, {4, 5, 6},
|
||||
{"1", "2", "3"});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({2.0f, 3.0f, 4.0f}));
|
||||
EXPECT_THAT(m.GetOutput<std::string>(), ElementsAreArray({"2", "3", "4"}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||
}
|
||||
|
||||
@ -863,19 +879,27 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
|
||||
PopulateTensor(keys_, data);
|
||||
}
|
||||
|
||||
void SetStringKeys(const std::vector<std::string>& data) {
|
||||
PopulateStringTensor(keys_, data);
|
||||
}
|
||||
|
||||
void SetValues(const std::vector<ValueType>& data) {
|
||||
PopulateTensor(values_, data);
|
||||
}
|
||||
|
||||
void SetStringValues(const std::vector<std::string>& data) {
|
||||
PopulateStringTensor(values_, data);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableImport) {
|
||||
const int kResourceId = 42;
|
||||
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||
TensorType_FLOAT32, 3);
|
||||
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||
TensorType_STRING, 3);
|
||||
EXPECT_EQ(m.GetResources().size(), 0);
|
||||
m.SetResourceId({kResourceId});
|
||||
m.SetKeys({1, 2, 3});
|
||||
m.SetValues({1.0f, 2.0f, 3.0f});
|
||||
m.SetStringValues({"1", "2", "3"});
|
||||
m.CreateHashtableResource(kResourceId);
|
||||
m.Invoke();
|
||||
|
||||
@ -883,20 +907,20 @@ TEST(HashtableOpsTest, TestHashtableImport) {
|
||||
EXPECT_EQ(resources.size(), 1);
|
||||
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
||||
EXPECT_TRUE(hashtable != nullptr);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
|
||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
|
||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
|
||||
|
||||
EXPECT_EQ(hashtable->Size(), 3);
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
||||
const int kResourceId = 42;
|
||||
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||
TensorType_FLOAT32, 3);
|
||||
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||
TensorType_STRING, 3);
|
||||
EXPECT_EQ(m.GetResources().size(), 0);
|
||||
m.SetResourceId({kResourceId});
|
||||
m.SetKeys({1, 2, 3});
|
||||
m.SetValues({1.0f, 2.0f, 3.0f});
|
||||
m.SetStringValues({"1", "2", "3"});
|
||||
m.CreateHashtableResource(kResourceId);
|
||||
m.Invoke();
|
||||
m.Invoke();
|
||||
@ -905,8 +929,8 @@ TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
||||
EXPECT_EQ(resources.size(), 1);
|
||||
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
||||
EXPECT_TRUE(hashtable != nullptr);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
|
||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
|
||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
|
||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
|
||||
EXPECT_EQ(hashtable->Size(), 3);
|
||||
}
|
||||
|
||||
@ -920,7 +944,7 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
|
||||
|
||||
resource_id_ = AddInput({TensorType_INT32, {1}});
|
||||
|
||||
output_ = AddOutput({TensorType_INT32, {1}});
|
||||
output_ = AddOutput({TensorType_INT64, {1}});
|
||||
|
||||
SetCustomOp("HASHTABLE_SIZE", {},
|
||||
tflite::ops::custom::Register_HASHTABLE_SIZE);
|
||||
@ -930,23 +954,24 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableSize) {
|
||||
const int kResourceId = 42;
|
||||
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||
TensorType_INT64);
|
||||
|
||||
m.SetResourceId({kResourceId});
|
||||
|
||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
||||
InitHashtableResource<std::string, std::int64_t>(
|
||||
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
|
||||
{"4", "5", "6"}, {1, 2, 3});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3}));
|
||||
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({3}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||
}
|
||||
|
||||
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
|
||||
const int kResourceId = 42;
|
||||
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||
TensorType_INT32);
|
||||
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||
TensorType_INT64);
|
||||
m.SetResourceId({kResourceId});
|
||||
|
||||
// Invoke without hash table initialization.
|
||||
|
@ -40,7 +40,7 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output_tensor != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt32);
|
||||
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
|
||||
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
|
||||
outputSize->data[0] = 1;
|
||||
return context->ResizeTensor(context, output_tensor, outputSize);
|
||||
@ -52,7 +52,7 @@ TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
|
||||
int resource_id = input_resource_id_tensor->data.i32[0];
|
||||
|
||||
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
|
||||
auto* output_data = GetTensorData<int>(output_tensor);
|
||||
auto* output_data = GetTensorData<std::int64_t>(output_tensor);
|
||||
|
||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
auto& resources = subgraph->resources();
|
||||
@ -69,9 +69,7 @@ TfLiteRegistration* Register_HASHTABLE_SIZE() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
hashtable::PrepareHashtableSize,
|
||||
hashtable::EvalHashtableSize,
|
||||
nullptr,
|
||||
BuiltinOperator_CUSTOM};
|
||||
hashtable::EvalHashtableSize};
|
||||
return &r;
|
||||
}
|
||||
|
||||
|
@ -80,33 +80,14 @@ TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename KeyType>
|
||||
LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
|
||||
TfLiteType value_type) {
|
||||
switch (value_type) {
|
||||
case kTfLiteInt32:
|
||||
return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
|
||||
case kTfLiteString:
|
||||
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
|
||||
case kTfLiteFloat32:
|
||||
return new StaticHashtable<KeyType, float>(key_type, value_type);
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
LookupInterface* CreateStaticHashtable(TfLiteType key_type,
|
||||
TfLiteType value_type) {
|
||||
switch (key_type) {
|
||||
case kTfLiteInt32:
|
||||
return CreateStaticHashtableWithGivenKey<std::int32_t>(key_type,
|
||||
value_type);
|
||||
case kTfLiteString:
|
||||
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
|
||||
value_type);
|
||||
default:
|
||||
return nullptr;
|
||||
if (key_type == kTfLiteInt64 && value_type == kTfLiteString) {
|
||||
return new StaticHashtable<std::int64_t, std::string>(key_type, value_type);
|
||||
} else if (key_type == kTfLiteString && value_type == kTfLiteInt64) {
|
||||
return new StaticHashtable<std::string, std::int64_t>(key_type, value_type);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
@ -588,6 +588,7 @@ cc_library(
|
||||
":op_macros",
|
||||
"//tensorflow/lite:context",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/kernels:hashtable_op_kernels",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
"//third_party/fft2d:fft2d_headers",
|
||||
|
@ -22,7 +22,10 @@ namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
|
||||
TfLiteRegistration* Register_HASHTABLE();
|
||||
TfLiteRegistration* Register_HASHTABLE_FIND();
|
||||
TfLiteRegistration* Register_HASHTABLE_IMPORT();
|
||||
TfLiteRegistration* Register_HASHTABLE_SIZE();
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -322,6 +322,15 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
|
||||
buildinop_resolver_->AddCustom("RFFT2D",
|
||||
tflite::ops::custom::Register_RFFT2D());
|
||||
buildinop_resolver_->AddCustom("HashTableV2",
|
||||
tflite::ops::custom::Register_HASHTABLE());
|
||||
buildinop_resolver_->AddCustom(
|
||||
"LookupTableFindV2", tflite::ops::custom::Register_HASHTABLE_FIND());
|
||||
buildinop_resolver_->AddCustom(
|
||||
"LookupTableImportV2",
|
||||
tflite::ops::custom::Register_HASHTABLE_IMPORT());
|
||||
buildinop_resolver_->AddCustom(
|
||||
"LookupTableSizeV2", tflite::ops::custom::Register_HASHTABLE_SIZE());
|
||||
}
|
||||
|
||||
switch (delegate_type) {
|
||||
|
Loading…
Reference in New Issue
Block a user