Make Hashtable ops in TFLite compatible with MLIR converter

PiperOrigin-RevId: 288628649
Change-Id: Ie7dbd06e1e277e33da5d08993d27d08920837f3a
This commit is contained in:
Jaesung Chung 2020-01-07 21:05:24 -08:00 committed by TensorFlower Gardener
parent 7c965ff357
commit 57b4fbbfdb
10 changed files with 219 additions and 202 deletions

View File

@ -130,13 +130,14 @@ cc_library(
name = "hashtable_op_kernels", name = "hashtable_op_kernels",
srcs = [ srcs = [
"hashtable.cc", "hashtable.cc",
"hashtable_find.cc",
"hashtable_import.cc", "hashtable_import.cc",
"hashtable_lookup.cc",
"hashtable_size.cc", "hashtable_size.cc",
], ],
deps = [ deps = [
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/experimental/resource", "//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:op_macros",

View File

@ -15,7 +15,9 @@ limitations under the License.
#include <string> #include <string>
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers #include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h" #include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
@ -26,7 +28,10 @@ namespace ops {
namespace custom { namespace custom {
namespace hashtable { namespace hashtable {
constexpr int kResourceHandleTensor = 0; static constexpr int kResourceHandleTensor = 0;
static constexpr const char kSharedNameStr[] = "shared_name";
static constexpr const char kKeyDtypeStr[] = "key_dtype";
static constexpr const char kValueDtypeStr[] = "value_dtype";
// TODO(b/144728911): The following structure should be moved to // TODO(b/144728911): The following structure should be moved to
// builtin_op_data.h when it is ready to become a builtin op. // builtin_op_data.h when it is ready to become a builtin op.
@ -41,11 +46,18 @@ void* InitHashtable(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
const std::string table_name = m[kSharedNameStr].AsString().str();
TfLiteType key_dtype, value_dtype;
ConvertTensorType(static_cast<TensorType>(m[kKeyDtypeStr].AsInt32()),
&key_dtype, nullptr);
ConvertTensorType(static_cast<TensorType>(m[kValueDtypeStr].AsInt32()),
&value_dtype, nullptr);
TfLiteHashtableParams* option = new TfLiteHashtableParams; TfLiteHashtableParams* option = new TfLiteHashtableParams;
option->table_name = m["table_name"].AsString().str(); option->table_name = table_name;
option->key_dtype = static_cast<TfLiteType>(m["key_dtype"].AsInt32()); option->key_dtype = key_dtype;
option->value_dtype = static_cast<TfLiteType>(m["value_dtype"].AsInt32()); option->value_dtype = value_dtype;
return option; return option;
} }
@ -61,12 +73,12 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->user_data != nullptr); TF_LITE_ENSURE(context, node->user_data != nullptr);
const auto* params = const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data); reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
TF_LITE_ENSURE(context, !params->table_name.empty()); TF_LITE_ENSURE(context, !params->table_name.empty());
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt32 || TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt64 &&
params->key_dtype == kTfLiteString)); params->value_dtype == kTfLiteString) ||
TF_LITE_ENSURE(context, (params->value_dtype == kTfLiteInt32 || (params->key_dtype == kTfLiteString &&
params->value_dtype == kTfLiteString || params->value_dtype == kTfLiteInt64));
params->value_dtype == kTfLiteFloat32));
TfLiteTensor* resource_handle_tensor = TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor); GetOutput(context, node, kResourceHandleTensor);
@ -78,6 +90,7 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
} }
TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->user_data != nullptr);
const auto* params = const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data); reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
@ -100,12 +113,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
} // namespace hashtable } // namespace hashtable
TfLiteRegistration* Register_HASHTABLE() { TfLiteRegistration* Register_HASHTABLE() {
static TfLiteRegistration r = {hashtable::InitHashtable, static TfLiteRegistration r = {
hashtable::FreeHashtable, hashtable::InitHashtable, hashtable::FreeHashtable,
hashtable::PrepareHashtable, hashtable::PrepareHashtable, hashtable::EvalHashtable};
hashtable::EvalHashtable,
nullptr,
BuiltinOperator_CUSTOM};
return &r; return &r;
} }

View File

@ -30,7 +30,7 @@ constexpr int kKeyTensor = 1;
constexpr int kDefaultValueTensor = 2; constexpr int kDefaultValueTensor = 2;
constexpr int kOutputTensor = 0; constexpr int kOutputTensor = 0;
TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@ -42,26 +42,19 @@ TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* default_value_tensor = const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor); GetInput(context, node, kDefaultValueTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(default_value_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(default_value_tensor, 0), 1);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (output_tensor->type == kTfLiteInt32 ||
output_tensor->type == kTfLiteString ||
output_tensor->type == kTfLiteFloat32));
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor); const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 || TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
key_tensor->type == kTfLiteString)); TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
if (output_tensor->type != kTfLiteString) { TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
return context->ResizeTensor(context, output_tensor, output_tensor->type == kTfLiteString) ||
TfLiteIntArrayCopy(key_tensor->dims)); (key_tensor->type == kTfLiteString &&
} output_tensor->type == kTfLiteInt64));
return kTfLiteOk; return context->ResizeTensor(context, output_tensor,
TfLiteIntArrayCopy(key_tensor->dims));
} }
TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor = const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor); GetInput(context, node, kInputResourceIdTensor);
int resource_id = input_resource_id_tensor->data.i32[0]; int resource_id = input_resource_id_tensor->data.i32[0];
@ -77,19 +70,18 @@ TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, lookup != nullptr); TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_STATUS( TF_LITE_ENSURE_STATUS(
lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor)); lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor));
return lookup->Lookup(context, key_tensor, output_tensor, auto result =
default_value_tensor); lookup->Lookup(context, key_tensor, output_tensor, default_value_tensor);
return result;
} }
} // namespace hashtable } // namespace hashtable
TfLiteRegistration* Register_HASHTABLE_LOOKUP() { TfLiteRegistration* Register_HASHTABLE_FIND() {
static TfLiteRegistration r = {/*init=*/nullptr, static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr, /*free=*/nullptr,
hashtable::PrepareHashtableLookup, hashtable::PrepareHashtableFind,
hashtable::EvalHashtableLookup, hashtable::EvalHashtableFind};
nullptr,
BuiltinOperator_CUSTOM};
return &r; return &r;
} }

View File

@ -40,13 +40,11 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1); TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor); const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
key_tensor->type == kTfLiteString));
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor); const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
TF_LITE_ENSURE(context, (value_tensor->type == kTfLiteInt32 || TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
value_tensor->type == kTfLiteString || value_tensor->type == kTfLiteString) ||
value_tensor->type == kTfLiteFloat32)); (key_tensor->type == kTfLiteString &&
value_tensor->type == kTfLiteInt64));
// TODO(b/144731295): Tensorflow lookup ops support 1-D vector in storing // TODO(b/144731295): Tensorflow lookup ops support 1-D vector in storing
// values. // values.
TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor)); TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor));
@ -69,7 +67,8 @@ TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor)); lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor));
// The hashtable resource will only be initialized once, attempting to // The hashtable resource will only be initialized once, attempting to
// initialize it multiple times will be a no-op. // initialize it multiple times will be a no-op.
return lookup->Import(context, key_tensor, value_tensor); auto result = lookup->Import(context, key_tensor, value_tensor);
return result;
} }
} // namespace hashtable } // namespace hashtable
@ -78,9 +77,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT() {
static TfLiteRegistration r = {/*init=*/nullptr, static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr, /*free=*/nullptr,
hashtable::PrepareHashtableImport, hashtable::PrepareHashtableImport,
hashtable::EvalHashtableImport, hashtable::EvalHashtableImport};
nullptr,
BuiltinOperator_CUSTOM};
return &r; return &r;
} }

View File

@ -34,7 +34,7 @@ namespace ops {
namespace custom { namespace custom {
TfLiteRegistration* Register_HASHTABLE(); TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_LOOKUP(); TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT(); TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE(); TfLiteRegistration* Register_HASHTABLE_SIZE();
@ -45,6 +45,10 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
static constexpr const char kSharedNameStr[] = "shared_name";
static constexpr const char kKeyDtypeStr[] = "key_dtype";
static constexpr const char kValueDtypeStr[] = "value_dtype";
typedef enum { typedef enum {
kResourceTensorId = 0, kResourceTensorId = 0,
kKeyTensorId = 1, kKeyTensorId = 1,
@ -84,6 +88,19 @@ void SetTensorData(Interpreter* interpreter, int tensorId,
buf.WriteToTensorAsVector(tensor); buf.WriteToTensorAsVector(tensor);
} }
TensorType ConvertTfLiteType(TfLiteType type) {
// Currently, hashtable kernels support INT64 and STRING types only.
switch (type) {
case kTfLiteInt64:
return TensorType_INT64;
case kTfLiteString:
return TensorType_STRING;
default:
CHECK(false); // Not reached.
return TensorType_MIN;
}
}
// HashtableGraph generates a graph with hash table ops. This class can create // HashtableGraph generates a graph with hash table ops. This class can create
// the following scenarios: // the following scenarios:
// //
@ -120,7 +137,7 @@ class HashtableGraph {
// Hash table lookup node. // Hash table lookup node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId}, {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_, {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index); &node_index);
// Hash table size node. // Hash table size node.
@ -142,7 +159,7 @@ class HashtableGraph {
// Hash table lookup node. // Hash table lookup node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId}, {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_, {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index); &node_index);
// Hash table size node. // Hash table size node.
@ -174,7 +191,7 @@ class HashtableGraph {
// Hash table lookup node. // Hash table lookup node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId}, {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_, {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index); &node_index);
// Hash table size node. // Hash table size node.
@ -201,7 +218,7 @@ class HashtableGraph {
// Hash table lookup node. // Hash table lookup node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId}, {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_, {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index); &node_index);
// Hash table size node. // Hash table size node.
@ -226,8 +243,8 @@ class HashtableGraph {
// Hash table two lookup node. // Hash table two lookup node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
{kResourceTwoTensorId, kQueryTwoTensorId, kDefaultValueTwoTensorId}, {kResourceTwoTensorId, kQueryTwoTensorId, kDefaultValueTwoTensorId},
{kResultTwoTensorId}, nullptr, 0, nullptr, {kResultTwoTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
hashtable_lookup_registration_, &node_index); &node_index);
// Hash table two size node. // Hash table two size node.
interpreter_->AddNodeWithParameters( interpreter_->AddNodeWithParameters(
@ -261,16 +278,16 @@ class HashtableGraph {
default_value_two_ = default_value; default_value_two_ = default_value;
} }
int GetTableSize() { int64_t GetTableSize() {
auto* size_tensor = interpreter_->tensor(kSizeTensorId); auto* size_tensor = interpreter_->tensor(kSizeTensorId);
auto size_tensor_shape = GetTensorShape(size_tensor); auto size_tensor_shape = GetTensorShape(size_tensor);
return GetTensorData<int>(size_tensor)[0]; return GetTensorData<int64_t>(size_tensor)[0];
} }
int GetTableTwoSize() { int64_t GetTableTwoSize() {
auto* size_tensor = interpreter_->tensor(kSizeTwoTensorId); auto* size_tensor = interpreter_->tensor(kSizeTwoTensorId);
auto size_tensor_shape = GetTensorShape(size_tensor); auto size_tensor_shape = GetTensorShape(size_tensor);
return GetTensorData<int>(size_tensor)[0]; return GetTensorData<int64_t>(size_tensor)[0];
} }
std::vector<ValueType> GetLookupResult() { std::vector<ValueType> GetLookupResult() {
@ -363,7 +380,7 @@ class HashtableGraph {
TfLiteQuantization()); TfLiteQuantization());
// Result tensor for size calculation. // Result tensor for size calculation.
interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt32, "", interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt64, "",
{1}, TfLiteQuantization()); {1}, TfLiteQuantization());
// Default value tensor for lookup. // Default value tensor for lookup.
@ -396,7 +413,7 @@ class HashtableGraph {
{static_cast<int>(queries_two_.size())}, TfLiteQuantization()); {static_cast<int>(queries_two_.size())}, TfLiteQuantization());
// Result tensor for size calculation. // Result tensor for size calculation.
interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt32, interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt64,
"", {1}, TfLiteQuantization()); "", {1}, TfLiteQuantization());
// Default value tensor for lookup. // Default value tensor for lookup.
@ -433,9 +450,9 @@ class HashtableGraph {
hashtable_registration_ = tflite::ops::custom::Register_HASHTABLE(); hashtable_registration_ = tflite::ops::custom::Register_HASHTABLE();
ASSERT_NE(hashtable_registration_, nullptr); ASSERT_NE(hashtable_registration_, nullptr);
hashtable_lookup_registration_ = hashtable_find_registration_ =
tflite::ops::custom::Register_HASHTABLE_LOOKUP(); tflite::ops::custom::Register_HASHTABLE_FIND();
ASSERT_NE(hashtable_lookup_registration_, nullptr); ASSERT_NE(hashtable_find_registration_, nullptr);
hashtable_import_registration_ = hashtable_import_registration_ =
tflite::ops::custom::Register_HASHTABLE_IMPORT(); tflite::ops::custom::Register_HASHTABLE_IMPORT();
@ -447,11 +464,15 @@ class HashtableGraph {
} }
std::vector<uint8_t> GetHashtableParamsInFlatbuffer() { std::vector<uint8_t> GetHashtableParamsInFlatbuffer() {
TensorType key_tensor_type = ConvertTfLiteType(key_type_);
TensorType value_tensor_type = ConvertTfLiteType(value_type_);
flexbuffers::Builder fbb; flexbuffers::Builder fbb;
fbb.Map([&]() { fbb.Map([&]() {
fbb.String("table_name", "test_table_name" + std::to_string(std::rand())); fbb.String(kSharedNameStr,
fbb.Int("key_dtype", key_type_); "test_table_name" + std::to_string(std::rand()));
fbb.Int("value_dtype", value_type_); fbb.Int(kKeyDtypeStr, key_tensor_type);
fbb.Int(kValueDtypeStr, value_tensor_type);
}); });
fbb.Finish(); fbb.Finish();
return fbb.GetBuffer(); return fbb.GetBuffer();
@ -475,7 +496,7 @@ class HashtableGraph {
// Op registrations. // Op registrations.
TfLiteRegistration* hashtable_registration_; TfLiteRegistration* hashtable_registration_;
TfLiteRegistration* hashtable_lookup_registration_; TfLiteRegistration* hashtable_find_registration_;
TfLiteRegistration* hashtable_import_registration_; TfLiteRegistration* hashtable_import_registration_;
TfLiteRegistration* hashtable_size_registration_; TfLiteRegistration* hashtable_size_registration_;
@ -539,64 +560,27 @@ class HashtableDefaultGraphTest {
std::vector<ValueType> lookup_result_; std::vector<ValueType> lookup_result_;
}; };
TEST(HashtableOpsTest, TestInt32ToInt32Hashtable) { TEST(HashtableOpsTest, TestInt64ToStringHashtable) {
HashtableDefaultGraphTest<int, int> t( HashtableDefaultGraphTest<std::int64_t, std::string> t(
kTfLiteInt32, kTfLiteInt32, kTfLiteInt64, kTfLiteString,
/*keys=*/{1, 2, 3}, /*values=*/{4, 5, 6}, /*queries=*/{2, 3, 4},
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
t.InvokeAndVerifyIntResult();
}
TEST(HashtableOpsTest, TestInt32ToFloat32Hashtable) {
HashtableDefaultGraphTest<int, float> t(
kTfLiteInt32, kTfLiteFloat32,
/*keys=*/{1, 2, 3}, /*values=*/{4.0f, 5.0f, 6.0f}, /*queries=*/{2, 3, 4},
/*default_value=*/-1.0f, /*table_size=*/3,
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
t.InvokeAndVerifyFloatResult();
}
TEST(HashtableOpsTest, TestInt32ToStringHashtable) {
HashtableDefaultGraphTest<int, std::string> t(
kTfLiteInt32, kTfLiteString,
/*keys=*/{1, 2, 3}, /*values=*/{"a", "b", "c"}, /*queries=*/{2, 3, 4}, /*keys=*/{1, 2, 3}, /*values=*/{"a", "b", "c"}, /*queries=*/{2, 3, 4},
/*default_value=*/"d", /*table_size=*/3, /*default_value=*/"d", /*table_size=*/3,
/*lookup_result=*/{"b", "c", "d"}); /*lookup_result=*/{"b", "c", "d"});
t.InvokeAndVerifyStringResult(); t.InvokeAndVerifyStringResult();
} }
TEST(HashtableOpsTest, TestStringToInt32Hashtable) { TEST(HashtableOpsTest, TestStringToInt64Hashtable) {
HashtableDefaultGraphTest<std::string, int> t( HashtableDefaultGraphTest<std::string, int64_t> t(
kTfLiteString, kTfLiteInt32, kTfLiteString, kTfLiteInt64,
/*keys=*/{"A", "B", "C"}, /*values=*/{4, 5, 6}, /*keys=*/{"A", "B", "C"}, /*values=*/{4, 5, 6},
/*queries=*/{"B", "C", "D"}, /*queries=*/{"B", "C", "D"},
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1}); /*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
t.InvokeAndVerifyIntResult(); t.InvokeAndVerifyIntResult();
} }
TEST(HashtableOpsTest, TestStringToFloat32Hashtable) {
HashtableDefaultGraphTest<std::string, float> t(
kTfLiteString, kTfLiteFloat32,
/*keys=*/{"A", "B", "C"}, /*values=*/{4.0f, 5.0f, 6.0f},
/*queries=*/{"B", "C", "D"},
/*default_value=*/-1.0f, /*table_size=*/3,
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
t.InvokeAndVerifyFloatResult();
}
TEST(HashtableOpsTest, TestStringToStringHashtable) {
HashtableDefaultGraphTest<std::string, std::string> t(
kTfLiteString, kTfLiteString,
/*keys=*/{"A", "B", "C"}, /*values=*/{"a", "b", "c"},
/*queries=*/{"B", "C", "D"},
/*default_value=*/"d", /*table_size=*/3,
/*lookup_result=*/{"b", "c", "d"});
t.InvokeAndVerifyStringResult();
}
TEST(HashtableOpsTest, TestNoImport) { TEST(HashtableOpsTest, TestNoImport) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32); HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetQuery({1, 2, 3}, -1); graph.SetQuery({"1", "2", "3"}, -1);
graph.AddTensors(); graph.AddTensors();
graph.BuildNoImportGraph(); graph.BuildNoImportGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk); EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@ -607,9 +591,9 @@ TEST(HashtableOpsTest, TestNoImport) {
} }
TEST(HashtableOpsTest, TestImportTwice) { TEST(HashtableOpsTest, TestImportTwice) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32); HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({1, 2, 3}, {4, 5, 6}); graph.SetTable({"1", "2", "3"}, {4, 5, 6});
graph.SetQuery({2, 3, 4}, -1); graph.SetQuery({"2", "3", "4"}, -1);
graph.AddTensors(); graph.AddTensors();
graph.BuildImportTwiceGraph(); graph.BuildImportTwiceGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk); EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@ -621,11 +605,11 @@ TEST(HashtableOpsTest, TestImportTwice) {
} }
TEST(HashtableOpsTest, TestTwoHashtables) { TEST(HashtableOpsTest, TestTwoHashtables) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32); HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({1, 2, 3}, {4, 5, 6}); graph.SetTable({"1", "2", "3"}, {4, 5, 6});
graph.SetQuery({2, 3, 4}, -1); graph.SetQuery({"2", "3", "4"}, -1);
graph.SetTableTwo({-1, -2, -3}, {7, 8, 9}); graph.SetTableTwo({"-1", "-2", "-3"}, {7, 8, 9});
graph.SetQueryForTableTwo({-4, -2, -3}, -2); graph.SetQueryForTableTwo({"-4", "-2", "-3"}, -2);
graph.AddTensors(/*table_two_initialization=*/true); graph.AddTensors(/*table_two_initialization=*/true);
graph.BuildTwoHashtablesGraph(); graph.BuildTwoHashtablesGraph();
EXPECT_EQ(graph.AllocateTensors(/*table_two_initialization=*/true), EXPECT_EQ(graph.AllocateTensors(/*table_two_initialization=*/true),
@ -639,9 +623,9 @@ TEST(HashtableOpsTest, TestTwoHashtables) {
} }
TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) { TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32); HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({1, 2, 3}, {4, 5}); graph.SetTable({"1", "2", "3"}, {4, 5});
graph.SetQuery({2, 3, 4}, -1); graph.SetQuery({"2", "3", "4"}, -1);
graph.AddTensors(); graph.AddTensors();
graph.BuildDefaultGraph(); graph.BuildDefaultGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteError); EXPECT_EQ(graph.AllocateTensors(), kTfLiteError);
@ -650,16 +634,16 @@ TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
// HashtableOpModel creates a model with one signle Hashtable op. // HashtableOpModel creates a model with one signle Hashtable op.
class HashtableOpModel : public SingleOpModel { class HashtableOpModel : public SingleOpModel {
public: public:
explicit HashtableOpModel(const char* table_name, TfLiteType key_dtype, explicit HashtableOpModel(const char* table_name, TensorType key_dtype,
TfLiteType value_dtype) { TensorType value_dtype) {
output_ = AddOutput(GetTensorType<int>()); output_ = AddOutput(GetTensorType<int>());
// Set up and pass in custom options using flexbuffer. // Set up and pass in custom options using flexbuffer.
flexbuffers::Builder fbb; flexbuffers::Builder fbb;
fbb.Map([&]() { fbb.Map([&]() {
fbb.String("table_name", std::string(table_name)); fbb.String(kSharedNameStr, std::string(table_name));
fbb.Int("key_dtype", key_dtype); fbb.Int(kKeyDtypeStr, key_dtype);
fbb.Int("value_dtype", value_dtype); fbb.Int(kValueDtypeStr, value_dtype);
}); });
fbb.Finish(); fbb.Finish();
SetCustomOp("HASHTABLE", fbb.GetBuffer(), SetCustomOp("HASHTABLE", fbb.GetBuffer(),
@ -679,7 +663,7 @@ class HashtableOpModel : public SingleOpModel {
}; };
TEST(HashtableOpsTest, TestHashtable) { TEST(HashtableOpsTest, TestHashtable) {
HashtableOpModel m("test_hashtable", kTfLiteInt32, kTfLiteString); HashtableOpModel m("test_hashtable", TensorType_INT64, TensorType_STRING);
EXPECT_EQ(m.GetResources().size(), 0); EXPECT_EQ(m.GetResources().size(), 0);
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
@ -689,12 +673,12 @@ TEST(HashtableOpsTest, TestHashtable) {
EXPECT_NE(resource_id, 0); EXPECT_NE(resource_id, 0);
auto* hashtable = resource::GetHashtableResource(&resources, resource_id); auto* hashtable = resource::GetHashtableResource(&resources, resource_id);
EXPECT_TRUE(hashtable != nullptr); EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32); EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString); EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
} }
template <typename T> template <typename T>
TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) { TfLiteTensor CreateTensor(TfLiteType type, const std::vector<T>& vec) {
TfLiteTensor tensor = {}; TfLiteTensor tensor = {};
TfLiteIntArray* dims = TfLiteIntArrayCreate(1); TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
dims->data[0] = vec.size(); dims->data[0] = vec.size();
@ -715,6 +699,28 @@ TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
return tensor; return tensor;
} }
template <>
TfLiteTensor CreateTensor(TfLiteType type,
const std::vector<std::string>& vec) {
TfLiteTensor tensor = {};
TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
dims->data[0] = vec.size();
tensor.dims = dims;
tensor.name = "";
tensor.params = {};
tensor.quantization = {kTfLiteNoQuantization, nullptr};
tensor.is_variable = false;
tensor.allocation_type = kTfLiteDynamic;
tensor.allocation = nullptr;
tensor.type = type;
DynamicBuffer buf;
for (std::string str : vec) {
buf.AddString(str.c_str(), str.size());
}
buf.WriteToTensor(&tensor, nullptr);
return tensor;
}
template <typename KeyType, typename ValueType> template <typename KeyType, typename ValueType>
void InitHashtableResource(resource::ResourceMap* resources, int resource_id, void InitHashtableResource(resource::ResourceMap* resources, int resource_id,
TfLiteType key_type, TfLiteType value_type, TfLiteType key_type, TfLiteType value_type,
@ -772,12 +778,12 @@ class BaseHashtableOpModel : public SingleOpModel {
TensorType value_type_; TensorType value_type_;
}; };
// HashtableLookupOpModel creates a model with a HashtableLookup op. // HashtableFindOpModel creates a model with a HashtableLookup op.
template <typename KeyType, typename ValueType> template <typename KeyType, typename ValueType>
class HashtableLookupOpModel : public BaseHashtableOpModel { class HashtableFindOpModel : public BaseHashtableOpModel {
public: public:
HashtableLookupOpModel(const TensorType key_type, const TensorType value_type, HashtableFindOpModel(const TensorType key_type, const TensorType value_type,
int lookup_size) { int lookup_size) {
key_type_ = key_type; key_type_ = key_type;
value_type_ = value_type; value_type_ = value_type;
@ -787,8 +793,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
output_ = AddOutput({value_type, {lookup_size}}); output_ = AddOutput({value_type, {lookup_size}});
SetCustomOp("HASHTABLE_LOOKUP", {}, SetCustomOp("HASHTABLE_FIND", {},
tflite::ops::custom::Register_HASHTABLE_LOOKUP); tflite::ops::custom::Register_HASHTABLE_FIND);
BuildInterpreter( BuildInterpreter(
{GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)}); {GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)});
} }
@ -797,46 +803,56 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
PopulateTensor(lookup_, data); PopulateTensor(lookup_, data);
} }
void SetStringLookup(const std::vector<std::string>& data) {
PopulateStringTensor(lookup_, data);
}
void SetDefaultValue(const std::vector<ValueType>& data) { void SetDefaultValue(const std::vector<ValueType>& data) {
PopulateTensor(default_value_, data); PopulateTensor(default_value_, data);
} }
void SetStringDefaultValue(const std::vector<std::string>& data) {
PopulateStringTensor(default_value_, data);
}
private: private:
int lookup_; int lookup_;
int default_value_; int default_value_;
}; };
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) { TEST(HashtableOpsTest, TestHashtableLookupStringToInt64) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32, HashtableFindOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT32, 3); TensorType_INT64, 3);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7}); m.SetStringLookup({"5", "6", "7"});
m.SetDefaultValue({4}); m.SetDefaultValue({4});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32, InitHashtableResource<std::string, std::int64_t>(
kTfLiteInt32, {4, 5, 6}, {1, 2, 3}); &m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
{"4", "5", "6"}, {1, 2, 3});
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4})); EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({2, 3, 4}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
} }
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) { TEST(HashtableOpsTest, TestHashtableLookupInt64ToString) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32, HashtableFindOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_FLOAT32, 3); TensorType_STRING, 3);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7}); m.SetLookup({5, 6, 7});
m.SetDefaultValue({4.0f}); m.SetStringDefaultValue({"4"});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32, InitHashtableResource<std::int64_t, std::string>(
kTfLiteFloat32, {4, 5, 6}, {1.0f, 2.0f, 3.0f}); &m.GetResources(), kResourceId, kTfLiteInt64, kTfLiteString, {4, 5, 6},
{"1", "2", "3"});
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({2.0f, 3.0f, 4.0f})); EXPECT_THAT(m.GetOutput<std::string>(), ElementsAreArray({"2", "3", "4"}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
} }
@ -863,19 +879,27 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
PopulateTensor(keys_, data); PopulateTensor(keys_, data);
} }
void SetStringKeys(const std::vector<std::string>& data) {
PopulateStringTensor(keys_, data);
}
void SetValues(const std::vector<ValueType>& data) { void SetValues(const std::vector<ValueType>& data) {
PopulateTensor(values_, data); PopulateTensor(values_, data);
} }
void SetStringValues(const std::vector<std::string>& data) {
PopulateStringTensor(values_, data);
}
}; };
TEST(HashtableOpsTest, TestHashtableImport) { TEST(HashtableOpsTest, TestHashtableImport) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32, HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_FLOAT32, 3); TensorType_STRING, 3);
EXPECT_EQ(m.GetResources().size(), 0); EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3}); m.SetKeys({1, 2, 3});
m.SetValues({1.0f, 2.0f, 3.0f}); m.SetStringValues({"1", "2", "3"});
m.CreateHashtableResource(kResourceId); m.CreateHashtableResource(kResourceId);
m.Invoke(); m.Invoke();
@ -883,20 +907,20 @@ TEST(HashtableOpsTest, TestHashtableImport) {
EXPECT_EQ(resources.size(), 1); EXPECT_EQ(resources.size(), 1);
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId); auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
EXPECT_TRUE(hashtable != nullptr); EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32); EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32); EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
EXPECT_EQ(hashtable->Size(), 3); EXPECT_EQ(hashtable->Size(), 3);
} }
TEST(HashtableOpsTest, TestHashtableImportTwice) { TEST(HashtableOpsTest, TestHashtableImportTwice) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32, HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_FLOAT32, 3); TensorType_STRING, 3);
EXPECT_EQ(m.GetResources().size(), 0); EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3}); m.SetKeys({1, 2, 3});
m.SetValues({1.0f, 2.0f, 3.0f}); m.SetStringValues({"1", "2", "3"});
m.CreateHashtableResource(kResourceId); m.CreateHashtableResource(kResourceId);
m.Invoke(); m.Invoke();
m.Invoke(); m.Invoke();
@ -905,8 +929,8 @@ TEST(HashtableOpsTest, TestHashtableImportTwice) {
EXPECT_EQ(resources.size(), 1); EXPECT_EQ(resources.size(), 1);
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId); auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
EXPECT_TRUE(hashtable != nullptr); EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32); EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32); EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
EXPECT_EQ(hashtable->Size(), 3); EXPECT_EQ(hashtable->Size(), 3);
} }
@ -920,7 +944,7 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
resource_id_ = AddInput({TensorType_INT32, {1}}); resource_id_ = AddInput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT32, {1}}); output_ = AddOutput({TensorType_INT64, {1}});
SetCustomOp("HASHTABLE_SIZE", {}, SetCustomOp("HASHTABLE_SIZE", {},
tflite::ops::custom::Register_HASHTABLE_SIZE); tflite::ops::custom::Register_HASHTABLE_SIZE);
@ -930,23 +954,24 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
TEST(HashtableOpsTest, TestHashtableSize) { TEST(HashtableOpsTest, TestHashtableSize) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32, HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT32); TensorType_INT64);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32, InitHashtableResource<std::string, std::int64_t>(
kTfLiteInt32, {4, 5, 6}, {1, 2, 3}); &m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
{"4", "5", "6"}, {1, 2, 3});
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({3}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
} }
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) { TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32, HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT32); TensorType_INT64);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
// Invoke without hash table initialization. // Invoke without hash table initialization.

View File

@ -40,7 +40,7 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor); TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr); TF_LITE_ENSURE(context, output_tensor != nullptr);
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1; outputSize->data[0] = 1;
return context->ResizeTensor(context, output_tensor, outputSize); return context->ResizeTensor(context, output_tensor, outputSize);
@ -52,7 +52,7 @@ TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
int resource_id = input_resource_id_tensor->data.i32[0]; int resource_id = input_resource_id_tensor->data.i32[0];
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor); TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
auto* output_data = GetTensorData<int>(output_tensor); auto* output_data = GetTensorData<std::int64_t>(output_tensor);
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_); Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources(); auto& resources = subgraph->resources();
@ -69,9 +69,7 @@ TfLiteRegistration* Register_HASHTABLE_SIZE() {
static TfLiteRegistration r = {/*init=*/nullptr, static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr, /*free=*/nullptr,
hashtable::PrepareHashtableSize, hashtable::PrepareHashtableSize,
hashtable::EvalHashtableSize, hashtable::EvalHashtableSize};
nullptr,
BuiltinOperator_CUSTOM};
return &r; return &r;
} }

View File

@ -80,33 +80,14 @@ TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
return kTfLiteOk; return kTfLiteOk;
} }
template <typename KeyType>
LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
TfLiteType value_type) {
switch (value_type) {
case kTfLiteInt32:
return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
case kTfLiteString:
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
case kTfLiteFloat32:
return new StaticHashtable<KeyType, float>(key_type, value_type);
default:
return nullptr;
}
}
LookupInterface* CreateStaticHashtable(TfLiteType key_type, LookupInterface* CreateStaticHashtable(TfLiteType key_type,
TfLiteType value_type) { TfLiteType value_type) {
switch (key_type) { if (key_type == kTfLiteInt64 && value_type == kTfLiteString) {
case kTfLiteInt32: return new StaticHashtable<std::int64_t, std::string>(key_type, value_type);
return CreateStaticHashtableWithGivenKey<std::int32_t>(key_type, } else if (key_type == kTfLiteString && value_type == kTfLiteInt64) {
value_type); return new StaticHashtable<std::string, std::int64_t>(key_type, value_type);
case kTfLiteString:
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
value_type);
default:
return nullptr;
} }
return nullptr;
} }
} // namespace internal } // namespace internal

View File

@ -588,6 +588,7 @@ cc_library(
":op_macros", ":op_macros",
"//tensorflow/lite:context", "//tensorflow/lite:context",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/kernels:hashtable_op_kernels",
"//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/kernels/internal:tensor",
"//third_party/fft2d:fft2d_headers", "//third_party/fft2d:fft2d_headers",

View File

@ -22,7 +22,10 @@ namespace ops {
namespace custom { namespace custom {
TfLiteRegistration* Register_RFFT2D(); TfLiteRegistration* Register_RFFT2D();
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
} }
} // namespace ops } // namespace ops
} // namespace tflite } // namespace tflite

View File

@ -322,6 +322,15 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get()); reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
buildinop_resolver_->AddCustom("RFFT2D", buildinop_resolver_->AddCustom("RFFT2D",
tflite::ops::custom::Register_RFFT2D()); tflite::ops::custom::Register_RFFT2D());
buildinop_resolver_->AddCustom("HashTableV2",
tflite::ops::custom::Register_HASHTABLE());
buildinop_resolver_->AddCustom(
"LookupTableFindV2", tflite::ops::custom::Register_HASHTABLE_FIND());
buildinop_resolver_->AddCustom(
"LookupTableImportV2",
tflite::ops::custom::Register_HASHTABLE_IMPORT());
buildinop_resolver_->AddCustom(
"LookupTableSizeV2", tflite::ops::custom::Register_HASHTABLE_SIZE());
} }
switch (delegate_type) { switch (delegate_type) {