Make Hashtable ops in TFLite compatible with MLIR converter
PiperOrigin-RevId: 288628649
Change-Id: Ie7dbd06e1e277e33da5d08993d27d08920837f3a
parent 7c965ff357, commit 57b4fbbfdb
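
Not part of the commit itself: a sketch of how a binary that runs a model produced by the MLIR converter could resolve these kernels. The converter keeps the TensorFlow op names, so the kernels are registered as custom ops under those names, mirroring the AddCustom calls added to the test driver at the end of this diff. The helper function name and the surrounding setup are illustrative.

    #include <memory>

    #include "tensorflow/lite/interpreter.h"
    #include "tensorflow/lite/kernels/register.h"
    #include "tensorflow/lite/model.h"

    namespace tflite {
    namespace ops {
    namespace custom {
    // Provided by the hashtable kernel target touched in this change.
    TfLiteRegistration* Register_HASHTABLE();
    TfLiteRegistration* Register_HASHTABLE_FIND();
    TfLiteRegistration* Register_HASHTABLE_IMPORT();
    TfLiteRegistration* Register_HASHTABLE_SIZE();
    }  // namespace custom
    }  // namespace ops
    }  // namespace tflite

    // Illustrative helper: resolves the converter's hashtable ops by their
    // TensorFlow names and builds an interpreter.
    std::unique_ptr<tflite::Interpreter> BuildInterpreterWithHashtableOps(
        const tflite::FlatBufferModel& model) {
      tflite::ops::builtin::BuiltinOpResolver resolver;
      resolver.AddCustom("HashTableV2", tflite::ops::custom::Register_HASHTABLE());
      resolver.AddCustom("LookupTableFindV2",
                         tflite::ops::custom::Register_HASHTABLE_FIND());
      resolver.AddCustom("LookupTableImportV2",
                         tflite::ops::custom::Register_HASHTABLE_IMPORT());
      resolver.AddCustom("LookupTableSizeV2",
                         tflite::ops::custom::Register_HASHTABLE_SIZE());
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(model, resolver)(&interpreter);
      return interpreter;
    }
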
@@ -130,13 +130,14 @@ cc_library(
     name = "hashtable_op_kernels",
     srcs = [
         "hashtable.cc",
+        "hashtable_find.cc",
         "hashtable_import.cc",
-        "hashtable_lookup.cc",
         "hashtable_size.cc",
     ],
     deps = [
         "//tensorflow/lite:framework",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/api",
         "//tensorflow/lite/experimental/resource",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",

@@ -15,7 +15,9 @@ limitations under the License.
 #include <string>
 
 #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
 #include "tensorflow/lite/core/subgraph.h"
 #include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -26,7 +28,10 @@ namespace ops {
 namespace custom {
 namespace hashtable {
 
-constexpr int kResourceHandleTensor = 0;
+static constexpr int kResourceHandleTensor = 0;
+static constexpr const char kSharedNameStr[] = "shared_name";
+static constexpr const char kKeyDtypeStr[] = "key_dtype";
+static constexpr const char kValueDtypeStr[] = "value_dtype";
 
 // TODO(b/144728911): The following structure should be moved to
 // builtin_op_data.h when it is ready to become a builtin op.
@@ -41,11 +46,18 @@ void* InitHashtable(TfLiteContext* context, const char* buffer, size_t length) {
 
   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  const std::string table_name = m[kSharedNameStr].AsString().str();
+
+  TfLiteType key_dtype, value_dtype;
+  ConvertTensorType(static_cast<TensorType>(m[kKeyDtypeStr].AsInt32()),
+                    &key_dtype, nullptr);
+  ConvertTensorType(static_cast<TensorType>(m[kValueDtypeStr].AsInt32()),
+                    &value_dtype, nullptr);
 
   TfLiteHashtableParams* option = new TfLiteHashtableParams;
-  option->table_name = m["table_name"].AsString().str();
-  option->key_dtype = static_cast<TfLiteType>(m["key_dtype"].AsInt32());
-  option->value_dtype = static_cast<TfLiteType>(m["value_dtype"].AsInt32());
+  option->table_name = table_name;
+  option->key_dtype = key_dtype;
+  option->value_dtype = value_dtype;
 
   return option;
 }
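
For illustration (not in the change): custom options that the new InitHashtable parses. The attribute names and the TensorType integers match what GetHashtableParamsInFlatbuffer() writes in the test further down; the helper itself and the schema include path are assumptions.

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
    #include "tensorflow/lite/schema/schema_generated.h"

    // Hypothetical helper: serializes the "shared_name"/"key_dtype"/"value_dtype"
    // map for a string -> int64 table; InitHashtable converts the TensorType
    // values back to TfLiteType via ConvertTensorType().
    std::vector<uint8_t> MakeHashtableOptions(const std::string& shared_name) {
      flexbuffers::Builder fbb;
      fbb.Map([&]() {
        fbb.String("shared_name", shared_name);
        fbb.Int("key_dtype", tflite::TensorType_STRING);
        fbb.Int("value_dtype", tflite::TensorType_INT64);
      });
      fbb.Finish();
      return fbb.GetBuffer();
    }
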
@@ -61,12 +73,12 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, node->user_data != nullptr);
   const auto* params =
       reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
 
   TF_LITE_ENSURE(context, !params->table_name.empty());
-  TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt32 ||
-                           params->key_dtype == kTfLiteString));
-  TF_LITE_ENSURE(context, (params->value_dtype == kTfLiteInt32 ||
-                           params->value_dtype == kTfLiteString ||
-                           params->value_dtype == kTfLiteFloat32));
+  TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt64 &&
+                           params->value_dtype == kTfLiteString) ||
+                              (params->key_dtype == kTfLiteString &&
+                               params->value_dtype == kTfLiteInt64));
 
   TfLiteTensor* resource_handle_tensor =
       GetOutput(context, node, kResourceHandleTensor);
@@ -78,6 +90,7 @@ TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, node->user_data != nullptr);
   const auto* params =
       reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
 
@@ -100,12 +113,9 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace hashtable
 
 TfLiteRegistration* Register_HASHTABLE() {
-  static TfLiteRegistration r = {hashtable::InitHashtable,
-                                 hashtable::FreeHashtable,
-                                 hashtable::PrepareHashtable,
-                                 hashtable::EvalHashtable,
-                                 nullptr,
-                                 BuiltinOperator_CUSTOM};
+  static TfLiteRegistration r = {
+      hashtable::InitHashtable, hashtable::FreeHashtable,
+      hashtable::PrepareHashtable, hashtable::EvalHashtable};
   return &r;
 }
 
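
A side note on the registration change above, with a minimal stand-in op for illustration: only init/free/prepare/invoke are brace-initialized now, and the remaining TfLiteRegistration fields are left zero-initialized; the op resolver sets the custom name and builtin code when the op is registered, so the explicit nullptr/BuiltinOperator_CUSTOM entries were redundant. The example op below is hypothetical.

    #include "tensorflow/lite/c/common.h"

    namespace {

    // Stand-in kernel functions, for illustration only.
    TfLiteStatus NoopPrepare(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }

    TfLiteStatus NoopEval(TfLiteContext* context, TfLiteNode* node) {
      return kTfLiteOk;
    }

    }  // namespace

    TfLiteRegistration* Register_NOOP_EXAMPLE() {
      // Aggregate initialization: members after `invoke` are zero-initialized.
      static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                     /*prepare=*/NoopPrepare, /*invoke=*/NoopEval};
      return &r;
    }
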
@@ -30,7 +30,7 @@ constexpr int kKeyTensor = 1;
 constexpr int kDefaultValueTensor = 2;
 constexpr int kOutputTensor = 0;
 
-TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareHashtableFind(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
 
@@ -42,26 +42,19 @@ TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
 
   const TfLiteTensor* default_value_tensor =
       GetInput(context, node, kDefaultValueTensor);
-  TF_LITE_ENSURE_EQ(context, NumDimensions(default_value_tensor), 1);
-  TF_LITE_ENSURE_EQ(context, SizeOfDimension(default_value_tensor, 0), 1);
 
-  TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
-  TF_LITE_ENSURE(context, (output_tensor->type == kTfLiteInt32 ||
-                           output_tensor->type == kTfLiteString ||
-                           output_tensor->type == kTfLiteFloat32));
-
   const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
-  TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
-                           key_tensor->type == kTfLiteString));
-  if (output_tensor->type != kTfLiteString) {
-    return context->ResizeTensor(context, output_tensor,
-                                 TfLiteIntArrayCopy(key_tensor->dims));
-  }
-  return kTfLiteOk;
+  TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
+  TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
+                           output_tensor->type == kTfLiteString) ||
+                              (key_tensor->type == kTfLiteString &&
+                               output_tensor->type == kTfLiteInt64));
+  return context->ResizeTensor(context, output_tensor,
+                               TfLiteIntArrayCopy(key_tensor->dims));
 }
 
-TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus EvalHashtableFind(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input_resource_id_tensor =
       GetInput(context, node, kInputResourceIdTensor);
   int resource_id = input_resource_id_tensor->data.i32[0];
@@ -77,19 +70,18 @@ TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, lookup != nullptr);
   TF_LITE_ENSURE_STATUS(
       lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor));
-  return lookup->Lookup(context, key_tensor, output_tensor,
-                        default_value_tensor);
+  auto result =
+      lookup->Lookup(context, key_tensor, output_tensor, default_value_tensor);
+  return result;
 }
 
 }  // namespace hashtable
 
-TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
+TfLiteRegistration* Register_HASHTABLE_FIND() {
   static TfLiteRegistration r = {/*init=*/nullptr,
                                  /*free=*/nullptr,
-                                 hashtable::PrepareHashtableLookup,
-                                 hashtable::EvalHashtableLookup,
-                                 nullptr,
-                                 BuiltinOperator_CUSTOM};
+                                 hashtable::PrepareHashtableFind,
+                                 hashtable::EvalHashtableFind};
   return &r;
 }
 

@@ -40,13 +40,11 @@ TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);
 
   const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
-  TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
-                           key_tensor->type == kTfLiteString));
-
   const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
-  TF_LITE_ENSURE(context, (value_tensor->type == kTfLiteInt32 ||
-                           value_tensor->type == kTfLiteString ||
-                           value_tensor->type == kTfLiteFloat32));
+  TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt64 &&
+                           value_tensor->type == kTfLiteString) ||
+                              (key_tensor->type == kTfLiteString &&
+                               value_tensor->type == kTfLiteInt64));
   // TODO(b/144731295): Tensorflow lookup ops support 1-D vector in storing
   // values.
   TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor));
@@ -69,7 +67,8 @@ TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
       lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor));
   // The hashtable resource will only be initialized once, attempting to
   // initialize it multiple times will be a no-op.
-  return lookup->Import(context, key_tensor, value_tensor);
+  auto result = lookup->Import(context, key_tensor, value_tensor);
+  return result;
 }
 
 }  // namespace hashtable
@@ -78,9 +77,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT() {
   static TfLiteRegistration r = {/*init=*/nullptr,
                                  /*free=*/nullptr,
                                  hashtable::PrepareHashtableImport,
-                                 hashtable::EvalHashtableImport,
-                                 nullptr,
-                                 BuiltinOperator_CUSTOM};
+                                 hashtable::EvalHashtableImport};
   return &r;
 }
 
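
The same key/value constraint now appears in PrepareHashtable, PrepareHashtableFind and PrepareHashtableImport. Reduced to a single predicate (a sketch; no such shared helper exists in the change), it is:

    #include "tensorflow/lite/c/common.h"

    // Only INT64 -> STRING and STRING -> INT64 tables are accepted, matching
    // the TF_LITE_ENSURE checks above.
    bool IsSupportedHashtableTypePair(TfLiteType key_type, TfLiteType value_type) {
      return (key_type == kTfLiteInt64 && value_type == kTfLiteString) ||
             (key_type == kTfLiteString && value_type == kTfLiteInt64);
    }
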
@@ -34,7 +34,7 @@ namespace ops {
 namespace custom {
 
 TfLiteRegistration* Register_HASHTABLE();
-TfLiteRegistration* Register_HASHTABLE_LOOKUP();
+TfLiteRegistration* Register_HASHTABLE_FIND();
 TfLiteRegistration* Register_HASHTABLE_IMPORT();
 TfLiteRegistration* Register_HASHTABLE_SIZE();
 
@@ -45,6 +45,10 @@ namespace {
 
 using ::testing::ElementsAreArray;
 
+static constexpr const char kSharedNameStr[] = "shared_name";
+static constexpr const char kKeyDtypeStr[] = "key_dtype";
+static constexpr const char kValueDtypeStr[] = "value_dtype";
+
 typedef enum {
   kResourceTensorId = 0,
   kKeyTensorId = 1,
@@ -84,6 +88,19 @@ void SetTensorData(Interpreter* interpreter, int tensorId,
   buf.WriteToTensorAsVector(tensor);
 }
 
+TensorType ConvertTfLiteType(TfLiteType type) {
+  // Currently, hashtable kernels support INT64 and STRING types only.
+  switch (type) {
+    case kTfLiteInt64:
+      return TensorType_INT64;
+    case kTfLiteString:
+      return TensorType_STRING;
+    default:
+      CHECK(false);  // Not reached.
+      return TensorType_MIN;
+  }
+}
+
 // HashtableGraph generates a graph with hash table ops. This class can create
 // the following scenarios:
 //
@@ -120,7 +137,7 @@ class HashtableGraph {
     // Hash table lookup node.
     interpreter_->AddNodeWithParameters(
         {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
-        {kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
+        {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
         &node_index);
 
     // Hash table size node.
@@ -142,7 +159,7 @@ class HashtableGraph {
     // Hash table lookup node.
     interpreter_->AddNodeWithParameters(
         {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
-        {kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
+        {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
        &node_index);
 
     // Hash table size node.
@@ -174,7 +191,7 @@ class HashtableGraph {
     // Hash table lookup node.
     interpreter_->AddNodeWithParameters(
         {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
-        {kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
+        {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
         &node_index);
 
     // Hash table size node.
@@ -201,7 +218,7 @@ class HashtableGraph {
     // Hash table lookup node.
     interpreter_->AddNodeWithParameters(
         {kResourceTensorId, kQueryTensorId, kDefaultValueTensorId},
-        {kResultTensorId}, nullptr, 0, nullptr, hashtable_lookup_registration_,
+        {kResultTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
         &node_index);
 
     // Hash table size node.
@@ -226,8 +243,8 @@ class HashtableGraph {
     // Hash table two lookup node.
     interpreter_->AddNodeWithParameters(
         {kResourceTwoTensorId, kQueryTwoTensorId, kDefaultValueTwoTensorId},
-        {kResultTwoTensorId}, nullptr, 0, nullptr,
-        hashtable_lookup_registration_, &node_index);
+        {kResultTwoTensorId}, nullptr, 0, nullptr, hashtable_find_registration_,
+        &node_index);
 
     // Hash table two size node.
     interpreter_->AddNodeWithParameters(
@@ -261,16 +278,16 @@ class HashtableGraph {
     default_value_two_ = default_value;
   }
 
-  int GetTableSize() {
+  int64_t GetTableSize() {
     auto* size_tensor = interpreter_->tensor(kSizeTensorId);
     auto size_tensor_shape = GetTensorShape(size_tensor);
-    return GetTensorData<int>(size_tensor)[0];
+    return GetTensorData<int64_t>(size_tensor)[0];
   }
 
-  int GetTableTwoSize() {
+  int64_t GetTableTwoSize() {
     auto* size_tensor = interpreter_->tensor(kSizeTwoTensorId);
     auto size_tensor_shape = GetTensorShape(size_tensor);
-    return GetTensorData<int>(size_tensor)[0];
+    return GetTensorData<int64_t>(size_tensor)[0];
   }
 
   std::vector<ValueType> GetLookupResult() {
@@ -363,7 +380,7 @@ class HashtableGraph {
                                                TfLiteQuantization());
 
     // Result tensor for size calculation.
-    interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt32, "",
+    interpreter_->SetTensorParametersReadWrite(kSizeTensorId, kTfLiteInt64, "",
                                                {1}, TfLiteQuantization());
 
     // Default value tensor for lookup.
@@ -396,7 +413,7 @@ class HashtableGraph {
         {static_cast<int>(queries_two_.size())}, TfLiteQuantization());
 
     // Result tensor for size calculation.
-    interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt32,
+    interpreter_->SetTensorParametersReadWrite(kSizeTwoTensorId, kTfLiteInt64,
                                                "", {1}, TfLiteQuantization());
 
     // Default value tensor for lookup.
@@ -433,9 +450,9 @@ class HashtableGraph {
     hashtable_registration_ = tflite::ops::custom::Register_HASHTABLE();
     ASSERT_NE(hashtable_registration_, nullptr);
 
-    hashtable_lookup_registration_ =
-        tflite::ops::custom::Register_HASHTABLE_LOOKUP();
-    ASSERT_NE(hashtable_lookup_registration_, nullptr);
+    hashtable_find_registration_ =
+        tflite::ops::custom::Register_HASHTABLE_FIND();
+    ASSERT_NE(hashtable_find_registration_, nullptr);
 
     hashtable_import_registration_ =
         tflite::ops::custom::Register_HASHTABLE_IMPORT();
@@ -447,11 +464,15 @@ class HashtableGraph {
   }
 
   std::vector<uint8_t> GetHashtableParamsInFlatbuffer() {
+    TensorType key_tensor_type = ConvertTfLiteType(key_type_);
+    TensorType value_tensor_type = ConvertTfLiteType(value_type_);
+
     flexbuffers::Builder fbb;
     fbb.Map([&]() {
-      fbb.String("table_name", "test_table_name" + std::to_string(std::rand()));
-      fbb.Int("key_dtype", key_type_);
-      fbb.Int("value_dtype", value_type_);
+      fbb.String(kSharedNameStr,
+                 "test_table_name" + std::to_string(std::rand()));
+      fbb.Int(kKeyDtypeStr, key_tensor_type);
+      fbb.Int(kValueDtypeStr, value_tensor_type);
     });
     fbb.Finish();
     return fbb.GetBuffer();
@@ -475,7 +496,7 @@ class HashtableGraph {
 
   // Op registrations.
   TfLiteRegistration* hashtable_registration_;
-  TfLiteRegistration* hashtable_lookup_registration_;
+  TfLiteRegistration* hashtable_find_registration_;
   TfLiteRegistration* hashtable_import_registration_;
   TfLiteRegistration* hashtable_size_registration_;
 
@@ -539,64 +560,27 @@ class HashtableDefaultGraphTest {
   std::vector<ValueType> lookup_result_;
 };
 
-TEST(HashtableOpsTest, TestInt32ToInt32Hashtable) {
-  HashtableDefaultGraphTest<int, int> t(
-      kTfLiteInt32, kTfLiteInt32,
-      /*keys=*/{1, 2, 3}, /*values=*/{4, 5, 6}, /*queries=*/{2, 3, 4},
-      /*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
-  t.InvokeAndVerifyIntResult();
-}
-
-TEST(HashtableOpsTest, TestInt32ToFloat32Hashtable) {
-  HashtableDefaultGraphTest<int, float> t(
-      kTfLiteInt32, kTfLiteFloat32,
-      /*keys=*/{1, 2, 3}, /*values=*/{4.0f, 5.0f, 6.0f}, /*queries=*/{2, 3, 4},
-      /*default_value=*/-1.0f, /*table_size=*/3,
-      /*lookup_result=*/{5.0f, 6.0f, -1.0f});
-  t.InvokeAndVerifyFloatResult();
-}
-
-TEST(HashtableOpsTest, TestInt32ToStringHashtable) {
-  HashtableDefaultGraphTest<int, std::string> t(
-      kTfLiteInt32, kTfLiteString,
+TEST(HashtableOpsTest, TestInt64ToStringHashtable) {
+  HashtableDefaultGraphTest<std::int64_t, std::string> t(
+      kTfLiteInt64, kTfLiteString,
       /*keys=*/{1, 2, 3}, /*values=*/{"a", "b", "c"}, /*queries=*/{2, 3, 4},
       /*default_value=*/"d", /*table_size=*/3,
       /*lookup_result=*/{"b", "c", "d"});
   t.InvokeAndVerifyStringResult();
 }
 
-TEST(HashtableOpsTest, TestStringToInt32Hashtable) {
-  HashtableDefaultGraphTest<std::string, int> t(
-      kTfLiteString, kTfLiteInt32,
+TEST(HashtableOpsTest, TestStringToInt64Hashtable) {
+  HashtableDefaultGraphTest<std::string, int64_t> t(
+      kTfLiteString, kTfLiteInt64,
       /*keys=*/{"A", "B", "C"}, /*values=*/{4, 5, 6},
       /*queries=*/{"B", "C", "D"},
       /*default_value=*/-1, /*table_size=*/3, /*lookup_result=*/{5, 6, -1});
   t.InvokeAndVerifyIntResult();
 }
 
-TEST(HashtableOpsTest, TestStringToFloat32Hashtable) {
-  HashtableDefaultGraphTest<std::string, float> t(
-      kTfLiteString, kTfLiteFloat32,
-      /*keys=*/{"A", "B", "C"}, /*values=*/{4.0f, 5.0f, 6.0f},
-      /*queries=*/{"B", "C", "D"},
-      /*default_value=*/-1.0f, /*table_size=*/3,
-      /*lookup_result=*/{5.0f, 6.0f, -1.0f});
-  t.InvokeAndVerifyFloatResult();
-}
-
-TEST(HashtableOpsTest, TestStringToStringHashtable) {
-  HashtableDefaultGraphTest<std::string, std::string> t(
-      kTfLiteString, kTfLiteString,
-      /*keys=*/{"A", "B", "C"}, /*values=*/{"a", "b", "c"},
-      /*queries=*/{"B", "C", "D"},
-      /*default_value=*/"d", /*table_size=*/3,
-      /*lookup_result=*/{"b", "c", "d"});
-  t.InvokeAndVerifyStringResult();
-}
-
 TEST(HashtableOpsTest, TestNoImport) {
-  HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
-  graph.SetQuery({1, 2, 3}, -1);
+  HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
+  graph.SetQuery({"1", "2", "3"}, -1);
   graph.AddTensors();
   graph.BuildNoImportGraph();
   EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@@ -607,9 +591,9 @@ TEST(HashtableOpsTest, TestNoImport) {
 }
 
 TEST(HashtableOpsTest, TestImportTwice) {
-  HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
-  graph.SetTable({1, 2, 3}, {4, 5, 6});
-  graph.SetQuery({2, 3, 4}, -1);
+  HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
+  graph.SetTable({"1", "2", "3"}, {4, 5, 6});
+  graph.SetQuery({"2", "3", "4"}, -1);
   graph.AddTensors();
   graph.BuildImportTwiceGraph();
   EXPECT_EQ(graph.AllocateTensors(), kTfLiteOk);
@@ -621,11 +605,11 @@ TEST(HashtableOpsTest, TestImportTwice) {
 }
 
 TEST(HashtableOpsTest, TestTwoHashtables) {
-  HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
-  graph.SetTable({1, 2, 3}, {4, 5, 6});
-  graph.SetQuery({2, 3, 4}, -1);
-  graph.SetTableTwo({-1, -2, -3}, {7, 8, 9});
-  graph.SetQueryForTableTwo({-4, -2, -3}, -2);
+  HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
+  graph.SetTable({"1", "2", "3"}, {4, 5, 6});
+  graph.SetQuery({"2", "3", "4"}, -1);
+  graph.SetTableTwo({"-1", "-2", "-3"}, {7, 8, 9});
+  graph.SetQueryForTableTwo({"-4", "-2", "-3"}, -2);
   graph.AddTensors(/*table_two_initialization=*/true);
   graph.BuildTwoHashtablesGraph();
   EXPECT_EQ(graph.AllocateTensors(/*table_two_initialization=*/true),
@@ -639,9 +623,9 @@ TEST(HashtableOpsTest, TestTwoHashtables) {
 }
 
 TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
-  HashtableGraph<int, int> graph(kTfLiteInt32, kTfLiteInt32);
-  graph.SetTable({1, 2, 3}, {4, 5});
-  graph.SetQuery({2, 3, 4}, -1);
+  HashtableGraph<std::string, std::int64_t> graph(kTfLiteString, kTfLiteInt64);
+  graph.SetTable({"1", "2", "3"}, {4, 5});
+  graph.SetQuery({"2", "3", "4"}, -1);
  graph.AddTensors();
  graph.BuildDefaultGraph();
  EXPECT_EQ(graph.AllocateTensors(), kTfLiteError);
@@ -650,16 +634,16 @@ TEST(HashtableOpsTest, TestImportDifferentKeyAndValueSize) {
 // HashtableOpModel creates a model with one signle Hashtable op.
 class HashtableOpModel : public SingleOpModel {
  public:
-  explicit HashtableOpModel(const char* table_name, TfLiteType key_dtype,
-                            TfLiteType value_dtype) {
+  explicit HashtableOpModel(const char* table_name, TensorType key_dtype,
+                            TensorType value_dtype) {
     output_ = AddOutput(GetTensorType<int>());
 
     // Set up and pass in custom options using flexbuffer.
     flexbuffers::Builder fbb;
     fbb.Map([&]() {
-      fbb.String("table_name", std::string(table_name));
-      fbb.Int("key_dtype", key_dtype);
-      fbb.Int("value_dtype", value_dtype);
+      fbb.String(kSharedNameStr, std::string(table_name));
+      fbb.Int(kKeyDtypeStr, key_dtype);
+      fbb.Int(kValueDtypeStr, value_dtype);
     });
     fbb.Finish();
     SetCustomOp("HASHTABLE", fbb.GetBuffer(),
@@ -679,7 +663,7 @@ class HashtableOpModel : public SingleOpModel {
 };
 
 TEST(HashtableOpsTest, TestHashtable) {
-  HashtableOpModel m("test_hashtable", kTfLiteInt32, kTfLiteString);
+  HashtableOpModel m("test_hashtable", TensorType_INT64, TensorType_STRING);
   EXPECT_EQ(m.GetResources().size(), 0);
   m.Invoke();
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
@@ -689,12 +673,12 @@ TEST(HashtableOpsTest, TestHashtable) {
   EXPECT_NE(resource_id, 0);
   auto* hashtable = resource::GetHashtableResource(&resources, resource_id);
   EXPECT_TRUE(hashtable != nullptr);
-  EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
+  EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
   EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
 }
 
 template <typename T>
-TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
+TfLiteTensor CreateTensor(TfLiteType type, const std::vector<T>& vec) {
   TfLiteTensor tensor = {};
   TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
   dims->data[0] = vec.size();
@@ -715,6 +699,28 @@ TfLiteTensor CreateTensor(TfLiteType type, std::vector<T> vec) {
   return tensor;
 }
 
+template <>
+TfLiteTensor CreateTensor(TfLiteType type,
+                          const std::vector<std::string>& vec) {
+  TfLiteTensor tensor = {};
+  TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
+  dims->data[0] = vec.size();
+  tensor.dims = dims;
+  tensor.name = "";
+  tensor.params = {};
+  tensor.quantization = {kTfLiteNoQuantization, nullptr};
+  tensor.is_variable = false;
+  tensor.allocation_type = kTfLiteDynamic;
+  tensor.allocation = nullptr;
+  tensor.type = type;
+  DynamicBuffer buf;
+  for (std::string str : vec) {
+    buf.AddString(str.c_str(), str.size());
+  }
+  buf.WriteToTensor(&tensor, nullptr);
+  return tensor;
+}
+
 template <typename KeyType, typename ValueType>
 void InitHashtableResource(resource::ResourceMap* resources, int resource_id,
                            TfLiteType key_type, TfLiteType value_type,
|
|||||||
TensorType value_type_;
|
TensorType value_type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// HashtableLookupOpModel creates a model with a HashtableLookup op.
|
// HashtableFindOpModel creates a model with a HashtableLookup op.
|
||||||
template <typename KeyType, typename ValueType>
|
template <typename KeyType, typename ValueType>
|
||||||
class HashtableLookupOpModel : public BaseHashtableOpModel {
|
class HashtableFindOpModel : public BaseHashtableOpModel {
|
||||||
public:
|
public:
|
||||||
HashtableLookupOpModel(const TensorType key_type, const TensorType value_type,
|
HashtableFindOpModel(const TensorType key_type, const TensorType value_type,
|
||||||
int lookup_size) {
|
int lookup_size) {
|
||||||
key_type_ = key_type;
|
key_type_ = key_type;
|
||||||
value_type_ = value_type;
|
value_type_ = value_type;
|
||||||
|
|
||||||
@ -787,8 +793,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
output_ = AddOutput({value_type, {lookup_size}});
|
output_ = AddOutput({value_type, {lookup_size}});
|
||||||
|
|
||||||
SetCustomOp("HASHTABLE_LOOKUP", {},
|
SetCustomOp("HASHTABLE_FIND", {},
|
||||||
tflite::ops::custom::Register_HASHTABLE_LOOKUP);
|
tflite::ops::custom::Register_HASHTABLE_FIND);
|
||||||
BuildInterpreter(
|
BuildInterpreter(
|
||||||
{GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)});
|
{GetShape(resource_id_), GetShape(lookup_), GetShape(default_value_)});
|
||||||
}
|
}
|
||||||
@ -797,46 +803,56 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
|
|||||||
PopulateTensor(lookup_, data);
|
PopulateTensor(lookup_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetStringLookup(const std::vector<std::string>& data) {
|
||||||
|
PopulateStringTensor(lookup_, data);
|
||||||
|
}
|
||||||
|
|
||||||
void SetDefaultValue(const std::vector<ValueType>& data) {
|
void SetDefaultValue(const std::vector<ValueType>& data) {
|
||||||
PopulateTensor(default_value_, data);
|
PopulateTensor(default_value_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetStringDefaultValue(const std::vector<std::string>& data) {
|
||||||
|
PopulateStringTensor(default_value_, data);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int lookup_;
|
int lookup_;
|
||||||
int default_value_;
|
int default_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
|
TEST(HashtableOpsTest, TestHashtableLookupStringToInt64) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
HashtableFindOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||||
TensorType_INT32, 3);
|
TensorType_INT64, 3);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetLookup({5, 6, 7});
|
m.SetStringLookup({"5", "6", "7"});
|
||||||
m.SetDefaultValue({4});
|
m.SetDefaultValue({4});
|
||||||
|
|
||||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
InitHashtableResource<std::string, std::int64_t>(
|
||||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
|
||||||
|
{"4", "5", "6"}, {1, 2, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4}));
|
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({2, 3, 4}));
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
|
TEST(HashtableOpsTest, TestHashtableLookupInt64ToString) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32,
|
HashtableFindOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||||
TensorType_FLOAT32, 3);
|
TensorType_STRING, 3);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetLookup({5, 6, 7});
|
m.SetLookup({5, 6, 7});
|
||||||
m.SetDefaultValue({4.0f});
|
m.SetStringDefaultValue({"4"});
|
||||||
|
|
||||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
InitHashtableResource<std::int64_t, std::string>(
|
||||||
kTfLiteFloat32, {4, 5, 6}, {1.0f, 2.0f, 3.0f});
|
&m.GetResources(), kResourceId, kTfLiteInt64, kTfLiteString, {4, 5, 6},
|
||||||
|
{"1", "2", "3"});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({2.0f, 3.0f, 4.0f}));
|
EXPECT_THAT(m.GetOutput<std::string>(), ElementsAreArray({"2", "3", "4"}));
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -863,19 +879,27 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
|
|||||||
PopulateTensor(keys_, data);
|
PopulateTensor(keys_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetStringKeys(const std::vector<std::string>& data) {
|
||||||
|
PopulateStringTensor(keys_, data);
|
||||||
|
}
|
||||||
|
|
||||||
void SetValues(const std::vector<ValueType>& data) {
|
void SetValues(const std::vector<ValueType>& data) {
|
||||||
PopulateTensor(values_, data);
|
PopulateTensor(values_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetStringValues(const std::vector<std::string>& data) {
|
||||||
|
PopulateStringTensor(values_, data);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableImport) {
|
TEST(HashtableOpsTest, TestHashtableImport) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||||
TensorType_FLOAT32, 3);
|
TensorType_STRING, 3);
|
||||||
EXPECT_EQ(m.GetResources().size(), 0);
|
EXPECT_EQ(m.GetResources().size(), 0);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetKeys({1, 2, 3});
|
m.SetKeys({1, 2, 3});
|
||||||
m.SetValues({1.0f, 2.0f, 3.0f});
|
m.SetStringValues({"1", "2", "3"});
|
||||||
m.CreateHashtableResource(kResourceId);
|
m.CreateHashtableResource(kResourceId);
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
@ -883,20 +907,20 @@ TEST(HashtableOpsTest, TestHashtableImport) {
|
|||||||
EXPECT_EQ(resources.size(), 1);
|
EXPECT_EQ(resources.size(), 1);
|
||||||
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
||||||
EXPECT_TRUE(hashtable != nullptr);
|
EXPECT_TRUE(hashtable != nullptr);
|
||||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
|
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
|
||||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
|
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
|
||||||
|
|
||||||
EXPECT_EQ(hashtable->Size(), 3);
|
EXPECT_EQ(hashtable->Size(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
HashtableImportOpModel<std::int64_t, std::string> m(TensorType_INT64,
|
||||||
TensorType_FLOAT32, 3);
|
TensorType_STRING, 3);
|
||||||
EXPECT_EQ(m.GetResources().size(), 0);
|
EXPECT_EQ(m.GetResources().size(), 0);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetKeys({1, 2, 3});
|
m.SetKeys({1, 2, 3});
|
||||||
m.SetValues({1.0f, 2.0f, 3.0f});
|
m.SetStringValues({"1", "2", "3"});
|
||||||
m.CreateHashtableResource(kResourceId);
|
m.CreateHashtableResource(kResourceId);
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
@ -905,8 +929,8 @@ TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
|||||||
EXPECT_EQ(resources.size(), 1);
|
EXPECT_EQ(resources.size(), 1);
|
||||||
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
auto* hashtable = resource::GetHashtableResource(&resources, kResourceId);
|
||||||
EXPECT_TRUE(hashtable != nullptr);
|
EXPECT_TRUE(hashtable != nullptr);
|
||||||
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt32);
|
EXPECT_TRUE(hashtable->GetKeyType() == kTfLiteInt64);
|
||||||
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteFloat32);
|
EXPECT_TRUE(hashtable->GetValueType() == kTfLiteString);
|
||||||
EXPECT_EQ(hashtable->Size(), 3);
|
EXPECT_EQ(hashtable->Size(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -920,7 +944,7 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
resource_id_ = AddInput({TensorType_INT32, {1}});
|
resource_id_ = AddInput({TensorType_INT32, {1}});
|
||||||
|
|
||||||
output_ = AddOutput({TensorType_INT32, {1}});
|
output_ = AddOutput({TensorType_INT64, {1}});
|
||||||
|
|
||||||
SetCustomOp("HASHTABLE_SIZE", {},
|
SetCustomOp("HASHTABLE_SIZE", {},
|
||||||
tflite::ops::custom::Register_HASHTABLE_SIZE);
|
tflite::ops::custom::Register_HASHTABLE_SIZE);
|
||||||
@ -930,23 +954,24 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableSize) {
|
TEST(HashtableOpsTest, TestHashtableSize) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||||
TensorType_INT32);
|
TensorType_INT64);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
|
|
||||||
InitHashtableResource(&m.GetResources(), kResourceId, kTfLiteInt32,
|
InitHashtableResource<std::string, std::int64_t>(
|
||||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
&m.GetResources(), kResourceId, kTfLiteString, kTfLiteInt64,
|
||||||
|
{"4", "5", "6"}, {1, 2, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3}));
|
EXPECT_THAT(m.GetOutput<std::int64_t>(), ElementsAreArray({3}));
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
|
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
HashtableSizeOpModel<std::string, std::int64_t> m(TensorType_STRING,
|
||||||
TensorType_INT32);
|
TensorType_INT64);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
|
|
||||||
// Invoke without hash table initialization.
|
// Invoke without hash table initialization.
|
||||||
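
A note on the string handling the test adds: string tensors use TFLite's packed string format, so the CreateTensor<std::string> specialization above writes through DynamicBuffer instead of memcpy. A matching reader, shown only for illustration, would use the string_util helpers:

    #include <string>
    #include <vector>

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/string_util.h"

    // Illustrative helper: unpacks every entry of a kTfLiteString tensor.
    std::vector<std::string> ReadStringTensor(const TfLiteTensor* tensor) {
      std::vector<std::string> result;
      const int count = tflite::GetStringCount(tensor);
      result.reserve(count);
      for (int i = 0; i < count; ++i) {
        const tflite::StringRef ref = tflite::GetString(tensor, i);
        result.emplace_back(ref.str, ref.len);
      }
      return result;
    }
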
@@ -40,7 +40,7 @@ TfLiteStatus PrepareHashtableSize(TfLiteContext* context, TfLiteNode* node) {
 
   TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
   TF_LITE_ENSURE(context, output_tensor != nullptr);
-  TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt32);
+  TF_LITE_ENSURE_EQ(context, output_tensor->type, kTfLiteInt64);
   TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
   outputSize->data[0] = 1;
   return context->ResizeTensor(context, output_tensor, outputSize);
@@ -52,7 +52,7 @@ TfLiteStatus EvalHashtableSize(TfLiteContext* context, TfLiteNode* node) {
   int resource_id = input_resource_id_tensor->data.i32[0];
 
   TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
-  auto* output_data = GetTensorData<int>(output_tensor);
+  auto* output_data = GetTensorData<std::int64_t>(output_tensor);
 
   Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
   auto& resources = subgraph->resources();
@@ -69,9 +69,7 @@ TfLiteRegistration* Register_HASHTABLE_SIZE() {
   static TfLiteRegistration r = {/*init=*/nullptr,
                                  /*free=*/nullptr,
                                  hashtable::PrepareHashtableSize,
-                                 hashtable::EvalHashtableSize,
-                                 nullptr,
-                                 BuiltinOperator_CUSTOM};
+                                 hashtable::EvalHashtableSize};
   return &r;
 }
 
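
Since the size output is now INT64 (matching LookupTableSizeV2), callers read it as int64_t. A minimal sketch, not part of the change:

    #include <cstdint>

    #include "tensorflow/lite/c/common.h"

    // PrepareHashtableSize resizes the output to a single INT64 element.
    int64_t ReadHashtableSize(const TfLiteTensor* size_tensor) {
      return size_tensor->data.i64[0];
    }
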
@@ -80,33 +80,14 @@ TfLiteStatus StaticHashtable<KeyType, ValueType>::Import(
   return kTfLiteOk;
 }
 
-template <typename KeyType>
-LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
-                                                   TfLiteType value_type) {
-  switch (value_type) {
-    case kTfLiteInt32:
-      return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
-    case kTfLiteString:
-      return new StaticHashtable<KeyType, std::string>(key_type, value_type);
-    case kTfLiteFloat32:
-      return new StaticHashtable<KeyType, float>(key_type, value_type);
-    default:
-      return nullptr;
-  }
-}
-
 LookupInterface* CreateStaticHashtable(TfLiteType key_type,
                                        TfLiteType value_type) {
-  switch (key_type) {
-    case kTfLiteInt32:
-      return CreateStaticHashtableWithGivenKey<std::int32_t>(key_type,
-                                                             value_type);
-    case kTfLiteString:
-      return CreateStaticHashtableWithGivenKey<std::string>(key_type,
-                                                            value_type);
-    default:
-      return nullptr;
+  if (key_type == kTfLiteInt64 && value_type == kTfLiteString) {
+    return new StaticHashtable<std::int64_t, std::string>(key_type, value_type);
+  } else if (key_type == kTfLiteString && value_type == kTfLiteInt64) {
+    return new StaticHashtable<std::string, std::int64_t>(key_type, value_type);
   }
+  return nullptr;
 }
 
 }  // namespace internal
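
With the switch collapsed, the factory above produces a table only for the two supported combinations and nullptr otherwise. A sketch of that behaviour; the namespace and the header that declares CreateStaticHashtable are assumptions based on the surrounding code:

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/experimental/resource/lookup_interfaces.h"

    void CheckStaticHashtableFactory() {
      // Assumed namespace for the factory shown in the hunk above.
      using tflite::resource::internal::CreateStaticHashtable;
      auto* string_to_int64 = CreateStaticHashtable(kTfLiteString, kTfLiteInt64);
      auto* int64_to_string = CreateStaticHashtable(kTfLiteInt64, kTfLiteString);
      auto* unsupported = CreateStaticHashtable(kTfLiteInt32, kTfLiteFloat32);
      // string_to_int64 and int64_to_string are non-null LookupInterface
      // pointers; unsupported is nullptr.
      delete string_to_int64;
      delete int64_to_string;
      delete unsupported;  // deleting nullptr is a no-op
    }
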
@@ -588,6 +588,7 @@ cc_library(
         ":op_macros",
         "//tensorflow/lite:context",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/experimental/kernels:hashtable_op_kernels",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/kernels/internal:tensor",
         "//third_party/fft2d:fft2d_headers",

@@ -22,7 +22,10 @@ namespace ops {
 namespace custom {
 
 TfLiteRegistration* Register_RFFT2D();
+TfLiteRegistration* Register_HASHTABLE();
+TfLiteRegistration* Register_HASHTABLE_FIND();
+TfLiteRegistration* Register_HASHTABLE_IMPORT();
+TfLiteRegistration* Register_HASHTABLE_SIZE();
 }
 }  // namespace ops
 }  // namespace tflite

@@ -322,6 +322,15 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
         reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
     buildinop_resolver_->AddCustom("RFFT2D",
                                    tflite::ops::custom::Register_RFFT2D());
+    buildinop_resolver_->AddCustom("HashTableV2",
+                                   tflite::ops::custom::Register_HASHTABLE());
+    buildinop_resolver_->AddCustom(
+        "LookupTableFindV2", tflite::ops::custom::Register_HASHTABLE_FIND());
+    buildinop_resolver_->AddCustom(
+        "LookupTableImportV2",
+        tflite::ops::custom::Register_HASHTABLE_IMPORT());
+    buildinop_resolver_->AddCustom(
+        "LookupTableSizeV2", tflite::ops::custom::Register_HASHTABLE_SIZE());
   }
 
   switch (delegate_type) {