Make Hashtable ops in TFLite compatible with MLIR converter

PiperOrigin-RevId: 288628649
Change-Id: Ie7dbd06e1e277e33da5d08993d27d08920837f3a
This commit is contained in:
Jaesung Chung 2020-01-07 21:05:24 -08:00 committed by TensorFlower Gardener
parent 7c965ff357
commit 57b4fbbfdb
10 changed files with 219 additions and 202 deletions

View File

@ -130,13 +130,14 @@ cc_library(
name = "hashtable_op_kernels",
srcs = [
"hashtable.cc",
"hashtable_find.cc",
"hashtable_import.cc",
"hashtable_lookup.cc",
"hashtable_size.cc",
],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/api",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",

View File

@ -15,7 +15,9 @@ limitations under the License.
#include <string>
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/kernel_util.h"
@ -26,7 +28,10 @@ namespace ops {
namespace custom {
namespace hashtable {
constexpr int kResourceHandleTensor = 0;
static constexpr int kResourceHandleTensor = 0;
static constexpr const char kSharedNameStr[] = "shared_name";
static constexpr const char kKeyDtypeStr[] = "key_dtype";
static constexpr const char kValueDtypeStr[] = "value_dtype";
// TODO(b/144728911): The following structure should be moved to
// builtin_op_data.h when it is ready to become a builtin op.
@ -41,11 +46,18 @@ void* InitHashtable(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
const std::string table_name = m[kSharedNameStr].AsString().str();
TfLiteType key_dtype, value_dtype;
ConvertTensorType(static_cast<TensorType>(m[kKeyDtypeStr].AsInt32()),
&key_dtype, nullptr);
ConvertTensorType(static_cast<TensorType>(m[kValueDtypeStr].AsInt32()),
&value_dtype, nullptr);
TfLiteHashtableParams* option = new TfLiteHashtableParams;
option->table_name = m["table_name"].AsString().str();
option->key_dtype = static_cast<TfLiteType>(m["key_dtype"].AsInt32());
option->value_dtype = static_cast<TfLiteType>(m["value_dtype"].AsInt32());
option->table_name = table_name;
option->key_dtype = key_dtype;
option->value_dtype = value_dtype;
return option;
}
@ -61,12 +73,12 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->user_data != nullptr);
const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
TF_LITE_ENSURE(context, !params->table_name.empty());
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt32 ||
params->key_dtype == kTfLiteString));
TF_LITE_ENSURE(context, (params->value_dtype == kTfLiteInt32 ||
params->value_dtype == kTfLiteString ||
params->value_dtype == kTfLiteFloat32));
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt64 &&
params->value_dtype == kTfLiteString) ||
(params->key_dtype == kTfLiteString &&
params->value_dtype == kTfLiteInt64));
TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
@ -78,6 +90,7 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->user_data != nullptr);
const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
@ -100,12 +113,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
} // namespace hashtable
TfLiteRegistration* Register_HASHTABLE() {
  // Registration for the custom HASHTABLE op. Only the first four
  // TfLiteRegistration fields (init/free/prepare/eval) are supplied; the
  // remaining fields are zero-initialized by the aggregate initializer, so
  // the op no longer tags itself with BuiltinOperator_CUSTOM explicitly.
  static TfLiteRegistration r = {
      hashtable::InitHashtable, hashtable::FreeHashtable,
      hashtable::PrepareHashtable, hashtable::EvalHashtable};
  return &r;
}

View File

@ -30,7 +30,7 @@ constexpr int kKeyTensor = 1;
constexpr int kDefaultValueTensor = 2;
constexpr int kOutputTensor = 0;
TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@ -42,26 +42,19 @@ TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(default_value_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(default_value_tensor, 0), 1);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (output_tensor->type == kTfLiteInt32 ||
output_tensor->type == kTfLiteString ||
output_tensor->type == kTfLiteFloat32));
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
key_tensor->type == kTfLiteString));
if (output_tensor->type != kTfLiteString) {
return context->ResizeTensor(context, output_tensor,
TfLiteIntArrayCopy(key_tensor->dims));
}
return kTfLiteOk;
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
output_tensor->type == kTfLiteString) ||
(key_tensor->type == kTfLiteString &&
output_tensor->type == kTfLiteInt64));
return context->ResizeTensor(context, output_tensor,
TfLiteIntArrayCopy(key_tensor->dims));
}
TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
int resource_id = input_resource_id_tensor->data.i32[0];
@ -77,19 +70,18 @@ TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_STATUS(
lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor));
return lookup->Lookup(context, key_tensor, output_tensor,
default_value_tensor);
auto result =
lookup->Lookup(context, key_tensor, output_tensor, default_value_tensor);
return result;
}
} // namespace hashtable
TfLiteRegistration* Register_HASHTABLE_FIND() {
  // Registration for the custom HASHTABLE_FIND op (renamed from
  // HASHTABLE_LOOKUP to mirror TF's LookupTableFindV2). Init/free are
  // unused; unspecified TfLiteRegistration fields are zero-initialized.
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 hashtable::PrepareHashtableFind,
                                 hashtable::EvalHashtableFind};
  return &r;
}

View File

@ -40,13 +40,11 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
key_tensor->type == kTfLiteString));
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
TF_LITE_ENSURE(context, (value_tensor->type == kTfLiteInt32 ||
value_tensor->type == kTfLiteString ||
value_tensor->type == kTfLiteFloat32));
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
value_tensor->type == kTfLiteString) ||
(key_tensor->type == kTfLiteString &&
value_tensor->type == kTfLiteInt64));
// TODO(b/144731295): Tensorflow lookup ops support 1-D vector in storing
// values.
TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor));
@ -69,7 +67,8 @@ TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor));
// The hashtable resource will only be initialized once, attempting to
// initialize it multiple times will be a no-op.
return lookup->Import(context, key_tensor, value_tensor);
auto result = lookup->Import(context, key_tensor, value_tensor);
return result;
}
} // namespace hashtable
@ -78,9 +77,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr,
hashtable::PrepareHashtableImport,
hashtable::EvalHashtableImport,
nullptr,
BuiltinOperator_CUSTOM};
hashtable::EvalHashtableImport};
return &r;
}

View File

@ -34,7 +34,7 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_LOOKUP();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
@ -45,6 +45,10 @@ namespace {
using ::testing::ElementsAreArray;
static constexpr const char kSharedNameStr[] = "shared_name";
static constexpr const char kKeyDtypeStr[] = "key_dtype";
static constexpr const char kValueDtypeStr[] = "value_dtype";
typedef enum {
kResourceTensorId = 0,
kKeyTensorId = 1,
@ -84,6 +88,19 @@ void SetTensorData(Interpreter* interpreter, int tensorId,
buf.WriteToTensorAsVector(tensor);
}
// Maps a runtime TfLiteType onto the flatbuffer schema TensorType.
// Hashtable kernels currently accept only INT64 and STRING; any other
// input is a programming error and trips the CHECK below.
TensorType ConvertTfLiteType(TfLiteType type) {
  if (type == kTfLiteInt64) {
    return TensorType_INT64;
  }
  if (type == kTfLiteString) {
    return TensorType_STRING;
  }
  CHECK(false);  // Unsupported type for hashtable kernels.
  return TensorType_MIN;
}
// HashtableGraph generates a graph with hash table ops. This class can create
// the following scenarios:
//
@ -120,7 +137,7 @@ class HashtableGraph {
// Hash table lookup node.
interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index);
// Hash table size node.
@ -142,7 +159,7 @@ class HashtableGraph {
// Hash table lookup node.
interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index);
// Hash table size node.
@ -174,7 +191,7 @@ class HashtableGraph {
// Hash table lookup node.
interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index);
// Hash table size node.
@ -201,7 +218,7 @@ class HashtableGraph {
// Hash table lookup node.
interpreter_->AddNodeWithParameters(
{kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
{kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
{kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index);
// Hash table size node.
@ -226,8 +243,8 @@ class HashtableGraph {
// Hash table two lookup node.
interpreter_->AddNodeWithParameters(
{kResourceTwoTensorId, kQueryTwoTensorId, kDefaultValueTwoTensorId},
{kResultTwoTensorId}, nullptr, 0, nullptr,
hashtable_lookup_registration_, &node_index);
{kResultTwoTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
&node_index);
// Hash table two size node.
interpreter_->AddNodeWithParameters(
@ -261,16 +278,16 @@ class HashtableGraph {
default_value_two_ = default_value;
}
int GetTableSize() {
int64_t GetTableSize() {
auto* size_tensor = interpreter_->tensor(kSizeTensorId);
auto size_tensor_shape = GetTensorShape(size_tensor);
return GetTensorData<int>(size_tensor)[0];
return GetTensorData<int64_t>(size_tensor)[0];
}
int GetTableTwoSize() {
int64_t GetTableTwoSize() {
auto* size_tensor = interpreter_->tensor(kSizeTwoTensorId);
auto size_tensor_shape = GetTensorShape(size_tensor);
return GetTensorData<int>(size_tensor)[0];
return GetTensorData<int64_t>(size_tensor)[0];
}
std::vector<ValueType> GetLookupResult() {
@ -363,7 +380,7 @@ class HashtableGraph {
TfLiteQuantization());
// Result tensor for size calculation.
interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt32, "",
interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt64, "",
{1}, TfLiteQuantization());
// Default value tensor for lookup.
@ -396,7 +413,7 @@ class HashtableGraph {
{static_cast<int>(queries_two_.size())}, TfLiteQuantization());
// Result tensor for size calculation.
interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt32,
interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt64,
"", {1}, TfLiteQuantization());
// Default value tensor for lookup.
@ -433,9 +450,9 @@ class HashtableGraph {
hashtable_registration_ = tflite::ops::custom::Register_HASHTABLE();
ASSERT_NE(hashtable_registration_, nullptr);
hashtable_lookup_registration_ =
tflite::ops::custom::Register_HASHTABLE_LOOKUP();
ASSERT_NE(hashtable_lookup_registration_, nullptr);
hashtable_find_registration_ =
tflite::ops::custom::Register_HASHTABLE_FIND();
ASSERT_NE(hashtable_find_registration_, nullptr);
hashtable_import_registration_ =
tflite::ops::custom::Register_HASHTABLE_IMPORT();
@ -447,11 +464,15 @@ class HashtableGraph {
}
std::vector<uint8_t> GetHashtableParamsInFlatbuffer() {
TensorType key_tensor_type = ConvertTfLiteType(key_type_);
TensorType value_tensor_type = ConvertTfLiteType(value_type_);
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.String("table_name", "test_table_name" + std::to_string(std::rand()));
fbb.Int("key_dtype", key_type_);
fbb.Int("value_dtype", value_type_);
fbb.String(kSharedNameStr,
"test_table_name" + std::to_string(std::rand()));
fbb.Int(kKeyDtypeStr, key_tensor_type);
fbb.Int(kValueDtypeStr, value_tensor_type);
});
fbb.Finish();
return fbb.GetBuffer();
@ -475,7 +496,7 @@ class HashtableGraph {
// Op registrations.
TfLiteRegistration* hashtable_registration_;
TfLiteRegistration* hashtable_lookup_registration_;
TfLiteRegistration* hashtable_find_registration_;
TfLiteRegistration* hashtable_import_registration_;
TfLiteRegistration* hashtable_size_registration_;
@ -539,64 +560,27 @@ class HashtableDefaultGraphTest {
std::vector<ValueType> lookup_result_;
};
TEST(HashtableOpsTest, TestInt32ToInt32Hashtable) {
HashtableDefaultGraphTest<int, int> t(
kTfLiteInt32, kTfLiteInt32,
/*keys=*/{1, 2, 3}, /*values=*/{4, 5, 6}, /*queries=*/{2, 3, 4},
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
t.InvokeAndVerifyIntResult();
}
TEST(HashtableOpsTest, TestInt32ToFloat32Hashtable) {
HashtableDefaultGraphTest<int, float> t(
kTfLiteInt32, kTfLiteFloat32,
/*keys=*/{1, 2, 3}, /*values=*/{4.0f, 5.0f, 6.0f}, /*queries=*/{2, 3, 4},
/*default_value=*/-1.0f, /*table_size=*/3,
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
t.InvokeAndVerifyFloatResult();
}
TEST(HashtableOpsTest, TestInt32ToStringHashtable) {
HashtableDefaultGraphTest<int, std::string> t(
kTfLiteInt32, kTfLiteString,
TEST(HashtableOpsTest, TestInt64ToStringHashtable) {
HashtableDefaultGraphTest<std::int64_t, std::string> t(
kTfLiteInt64, kTfLiteString,
/*keys=*/{1, 2, 3}, /*values=*/{"a", "b", "c"}, /*queries=*/{2, 3, 4},
/*default_value=*/"d", /*table_size=*/3,
/*lookup_result=*/{"b", "c", "d"});
t.InvokeAndVerifyStringResult();
}
TEST(HashtableOpsTest, TestStringToInt32Hashtable) {
HashtableDefaultGraphTest<std::string, int> t(
kTfLiteString, kTfLiteInt32,
TEST(HashtableOpsTest, TestStringToInt64Hashtable) {
HashtableDefaultGraphTest<std::string, int64_t> t(
kTfLiteString, kTfLiteInt64,
/*keys=*/{"A", "B", "C"}, /*values=*/{4, 5, 6},
/*queries=*/{"B", "C", "D"},
/*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
t.InvokeAndVerifyIntResult();
}
TEST(HashtableOpsTest, TestStringToFloat32Hashtable) {
HashtableDefaultGraphTest<std::string, float> t(
kTfLiteString, kTfLiteFloat32,
/*keys=*/{"A", "B", "C"}, /*values=*/{4.0f, 5.0f, 6.0f},
/*queries=*/{"B", "C", "D"},
/*default_value=*/-1.0f, /*table_size=*/3,
/*lookup_result=*/{5.0f, 6.0f, -1.0f});
t.InvokeAndVerifyFloatResult();
}
TEST(HashtableOpsTest, TestStringToStringHashtable) {
HashtableDefaultGraphTest<std::string, std::string> t(
kTfLiteString, kTfLiteString,
/*keys=*/{"A", "B", "C"}, /*values=*/{"a", "b", "c"},
/*queries=*/{"B", "C", "D"},
/*default_value=*/"d", /*table_size=*/3,
/*lookup_result=*/{"b", "c", "d"});
t.InvokeAndVerifyStringResult();
}
TEST(HashtableOpsTest, TestNoImport) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
graph.SetQuery({1, 2, 3}, -1);
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetQuery({"1", "2", "3"}, -1);
graph.AddTensors();
graph.BuildNoImportGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@ -607,9 +591,9 @@ TEST(HashtableOpsTest, TestNoImport) {
}
TEST(HashtableOpsTest, TestImportTwice) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
graph.SetTable({1, 2, 3}, {4, 5, 6});
graph.SetQuery({2, 3, 4}, -1);
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({"1", "2", "3"}, {4, 5, 6});
graph.SetQuery({"2", "3", "4"}, -1);
graph.AddTensors();
graph.BuildImportTwiceGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@ -621,11 +605,11 @@ TEST(HashtableOpsTest, TestImportTwice) {
}
TEST(HashtableOpsTest, TestTwoHashtables) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
graph.SetTable({1, 2, 3}, {4, 5, 6});
graph.SetQuery({2, 3, 4}, -1);
graph.SetTableTwo({-1, -2, -3}, {7, 8, 9});
graph.SetQueryForTableTwo({-4, -2, -3}, -2);
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({"1", "2", "3"}, {4, 5, 6});
graph.SetQuery({"2", "3", "4"}, -1);
graph.SetTableTwo({"-1", "-2", "-3"}, {7, 8, 9});
graph.SetQueryForTableTwo({"-4", "-2", "-3"}, -2);
graph.AddTensors(/*table_two_initialization=*/true);
graph.BuildTwoHashtablesGraph();
EXPECT_EQ(graph.AllocateTensors(/*table_two_initialization=*/true),
@ -639,9 +623,9 @@ TEST(HashtableOpsTest, TestTwoHashtables) {
}
TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
graph.SetTable({1, 2, 3}, {4, 5});
graph.SetQuery({2, 3, 4}, -1);
HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
graph.SetTable({"1", "2", "3"}, {4, 5});
graph.SetQuery({"2", "3", "4"}, -1);
graph.AddTensors();
graph.BuildDefaultGraph();
EXPECT_EQ(graph.AllocateTensors(), kTfLiteError);
@ -650,16 +634,16 @@ TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
// HashtableOpModel creates a model with one single Hashtable op.
class HashtableOpModel : public SingleOpModel {
public:
explicit HashtableOpModel(const char* table_name, TfLiteType key_dtype,
TfLiteType value_dtype) {
explicit HashtableOpModel(const char* table_name, TensorType key_dtype,
TensorType value_dtype) {
output_ = AddOutput(GetTensorType<int>());
// Set up and pass in custom options using flexbuffer.
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.String("table_name", std::string(table_name));
fbb.Int("key_dtype", key_dtype);
fbb.Int("value_dtype", value_dtype);
fbb.String(kSharedNameStr, std::string(table_name));
fbb.Int(kKeyDtypeStr, key_dtype);
fbb.Int(kValueDtypeStr, value_dtype);
});
fbb.Finish();
SetCustomOp("HASHTABLE", fbb.GetBuffer(),
@ -679,7 +663,7 @@ class HashtableOpModel : public SingleOpModel {
};
TEST(HashtableOpsTest, TestHashtable) {
HashtableOpModel m("test_hashtable", kTfLiteInt32, kTfLiteString);
HashtableOpModel m("test_hashtable", TensorType_INT64, TensorType_STRING);
EXPECT_EQ(m.GetResources().size(), 0);
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
@ -689,12 +673,12 @@ TEST(HashtableOpsTest, TestHashtable) {
EXPECT_NE(resource_id, 0);
auto* hashtable = resource::GetHashtableResource(&resources, resource_id);
EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
}
template <typename T>
TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
TfLiteTensor CreateTensor(TfLiteType type, const std::vector<T>& vec) {
TfLiteTensor tensor = {};
TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
dims->data[0] = vec.size();
@ -715,6 +699,28 @@ TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
return tensor;
}
// Specialization for string tensors: strings cannot be memcpy'd into the
// tensor buffer, so they are packed into the TFLite string format through
// DynamicBuffer. Returns a 1-D dynamic tensor holding vec.size() strings;
// the caller owns the tensor's dims/data and must release them.
template <>
TfLiteTensor CreateTensor(TfLiteType type,
                          const std::vector<std::string>& vec) {
  TfLiteTensor tensor = {};
  TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
  // Explicit cast: dims entries are int, vec.size() is size_t.
  dims->data[0] = static_cast<int>(vec.size());
  tensor.dims = dims;
  tensor.name = "";
  tensor.params = {};
  tensor.quantization = {kTfLiteNoQuantization, nullptr};
  tensor.is_variable = false;
  tensor.allocation_type = kTfLiteDynamic;
  tensor.allocation = nullptr;
  tensor.type = type;
  DynamicBuffer buf;
  // Iterate by const reference to avoid copying each string.
  for (const std::string& str : vec) {
    buf.AddString(str.c_str(), str.size());
  }
  buf.WriteToTensor(&tensor, nullptr);
  return tensor;
}
template <typename KeyType, typename ValueType>
void InitHashtableResource(resource::ResourceMap* resources, int resource_id,
TfLiteType key_type, TfLiteType value_type,
@ -772,12 +778,12 @@ class BaseHashtableOpModel : public SingleOpModel {
TensorType value_type_;
};
// HashtableLookupOpModel creates a model with a HashtableLookup op.
// HashtableFindOpModel creates a model with a HashtableFind op.
template <typename KeyType, typename ValueType>
class HashtableLookupOpModel : public BaseHashtableOpModel {
class HashtableFindOpModel : public BaseHashtableOpModel {
public:
HashtableLookupOpModel(const TensorType key_type, const TensorType value_type,
int lookup_size) {
HashtableFindOpModel(const TensorType key_type, const TensorType value_type,
int lookup_size) {
key_type_ = key_type;
value_type_ = value_type;
@ -787,8 +793,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
output_ = AddOutput({value_type, {lookup_size}});
SetCustomOp("HASHTABLE_LOOKUP", {},
tflite::ops::custom::Register_HASHTABLE_LOOKUP);
SetCustomOp("HASHTABLE_FIND", {},
tflite::ops::custom::Register_HASHTABLE_FIND);
BuildInterpreter(
{GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)});
}
@ -797,46 +803,56 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
PopulateTensor(lookup_, data);
}
void SetStringLookup(const std::vector<std::string>& data) {
PopulateStringTensor(lookup_, data);
}
void SetDefaultValue(const std::vector<ValueType>& data) {
PopulateTensor(default_value_, data);
}
void SetStringDefaultValue(const std::vector<std::string>& data) {
PopulateStringTensor(default_value_, data);
}
private:
int lookup_;
int default_value_;
};
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
TEST(HashtableOpsTest, TestHashtableLookupStringToInt64) {
const int kResourceId = 42;
HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32, 3);
HashtableFindOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT64, 3);
m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7});
m.SetStringLookup({"5", "6", "7"});
m.SetDefaultValue({4});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
InitHashtableResource<std::string, std::int64_t>(
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
{"4", "5", "6"}, {1, 2, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4}));
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({2, 3, 4}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
}
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
TEST(HashtableOpsTest, TestHashtableLookupInt64ToString) {
const int kResourceId = 42;
HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32,
TensorType_FLOAT32, 3);
HashtableFindOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_STRING, 3);
m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7});
m.SetDefaultValue({4.0f});
m.SetStringDefaultValue({"4"});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
kTfLiteFloat32, {4, 5, 6}, {1.0f, 2.0f, 3.0f});
InitHashtableResource<std::int64_t, std::string>(
&m.GetResources(), kResourceId, kTfLiteInt64, kTfLiteString, {4, 5, 6},
{"1", "2", "3"});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({2.0f, 3.0f, 4.0f}));
EXPECT_THAT(m.GetOutput<std::string>(), ElementsAreArray({"2", "3", "4"}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
}
@ -863,19 +879,27 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
PopulateTensor(keys_, data);
}
void SetStringKeys(const std::vector<std::string>& data) {
PopulateStringTensor(keys_, data);
}
void SetValues(const std::vector<ValueType>& data) {
PopulateTensor(values_, data);
}
void SetStringValues(const std::vector<std::string>& data) {
PopulateStringTensor(values_, data);
}
};
TEST(HashtableOpsTest, TestHashtableImport) {
const int kResourceId = 42;
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
TensorType_FLOAT32, 3);
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_STRING, 3);
EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3});
m.SetValues({1.0f, 2.0f, 3.0f});
m.SetStringValues({"1", "2", "3"});
m.CreateHashtableResource(kResourceId);
m.Invoke();
@ -883,20 +907,20 @@ TEST(HashtableOpsTest, TestHashtableImport) {
EXPECT_EQ(resources.size(), 1);
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
EXPECT_EQ(hashtable->Size(), 3);
}
TEST(HashtableOpsTest, TestHashtableImportTwice) {
const int kResourceId = 42;
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
TensorType_FLOAT32, 3);
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
TensorType_STRING, 3);
EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3});
m.SetValues({1.0f, 2.0f, 3.0f});
m.SetStringValues({"1", "2", "3"});
m.CreateHashtableResource(kResourceId);
m.Invoke();
m.Invoke();
@ -905,8 +929,8 @@ TEST(HashtableOpsTest, TestHashtableImportTwice) {
EXPECT_EQ(resources.size(), 1);
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
EXPECT_TRUE(hashtable != nullptr);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
EXPECT_EQ(hashtable->Size(), 3);
}
@ -920,7 +944,7 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
resource_id_ = AddInput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT64, {1}});
SetCustomOp("HASHTABLE_SIZE", {},
tflite::ops::custom::Register_HASHTABLE_SIZE);
@ -930,23 +954,24 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
TEST(HashtableOpsTest, TestHashtableSize) {
const int kResourceId = 42;
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32);
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT64);
m.SetResourceId({kResourceId});
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
InitHashtableResource<std::string, std::int64_t>(
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
{"4", "5", "6"}, {1, 2, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3}));
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({3}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
}
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
const int kResourceId = 42;
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32);
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
TensorType_INT64);
m.SetResourceId({kResourceId});
// Invoke without hash table initialization.

View File

@ -40,7 +40,7 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output_tensor != nullptr);
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
return context->ResizeTensor(context, output_tensor, outputSize);
@ -52,7 +52,7 @@ TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
int resource_id = input_resource_id_tensor->data.i32[0];
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
auto* output_data = GetTensorData<int>(output_tensor);
auto* output_data = GetTensorData<std::int64_t>(output_tensor);
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();
@ -69,9 +69,7 @@ TfLiteRegistration* Register_HASHTABLE_SIZE() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr,
hashtable::PrepareHashtableSize,
hashtable::EvalHashtableSize,
nullptr,
BuiltinOperator_CUSTOM};
hashtable::EvalHashtableSize};
return &r;
}

View File

@ -80,33 +80,14 @@ TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
return kTfLiteOk;
}
template <typename KeyType>
LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
TfLiteType value_type) {
switch (value_type) {
case kTfLiteInt32:
return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
case kTfLiteString:
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
case kTfLiteFloat32:
return new StaticHashtable<KeyType, float>(key_type, value_type);
default:
return nullptr;
}
}
// Creates a StaticHashtable for one of the supported key/value type
// combinations: (INT64 -> STRING) or (STRING -> INT64). Any other pairing
// returns nullptr so the caller can surface the error. The caller takes
// ownership of the returned object.
LookupInterface* CreateStaticHashtable(TfLiteType key_type,
                                       TfLiteType value_type) {
  if (key_type == kTfLiteInt64 && value_type == kTfLiteString) {
    return new StaticHashtable<std::int64_t, std::string>(key_type,
                                                          value_type);
  }
  if (key_type == kTfLiteString && value_type == kTfLiteInt64) {
    return new StaticHashtable<std::string, std::int64_t>(key_type,
                                                          value_type);
  }
  return nullptr;
}
} // namespace internal

View File

@ -588,6 +588,7 @@ cc_library(
":op_macros",
"//tensorflow/lite:context",
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/kernels:hashtable_op_kernels",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:tensor",
"//third_party/fft2d:fft2d_headers",

View File

@ -22,7 +22,10 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_RFFT2D();
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
}
} // namespace ops
} // namespace tflite

View File

@ -322,6 +322,15 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
buildinop_resolver_->AddCustom("RFFT2D",
tflite::ops::custom::Register_RFFT2D());
buildinop_resolver_->AddCustom("HashTableV2",
tflite::ops::custom::Register_HASHTABLE());
buildinop_resolver_->AddCustom(
"LookupTableFindV2", tflite::ops::custom::Register_HASHTABLE_FIND());
buildinop_resolver_->AddCustom(
"LookupTableImportV2",
tflite::ops::custom::Register_HASHTABLE_IMPORT());
buildinop_resolver_->AddCustom(
"LookupTableSizeV2", tflite::ops::custom::Register_HASHTABLE_SIZE());
}
switch (delegate_type) {