Use int32_t instead of int32 in resource code
PiperOrigin-RevId: 283781525 Change-Id: I2f3d51fc934ea98a3ef1c3504cb4b4b29a157d5a
This commit is contained in:
parent
de4e14925e
commit
9eea1b6de5
@ -86,7 +86,8 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TfLiteTensor* resource_handle_tensor =
|
TfLiteTensor* resource_handle_tensor =
|
||||||
GetOutput(context, node, kResourceHandleTensor);
|
GetOutput(context, node, kResourceHandleTensor);
|
||||||
auto* resource_handle_data = GetTensorData<int32>(resource_handle_tensor);
|
auto* resource_handle_data =
|
||||||
|
GetTensorData<std::int32_t>(resource_handle_tensor);
|
||||||
resource_handle_data[0] = resource_id;
|
resource_handle_data[0] = resource_id;
|
||||||
|
|
||||||
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||||
|
@ -808,7 +808,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
|
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableLookupOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32, 3);
|
HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||||
|
TensorType_INT32, 3);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetLookup({5, 6, 7});
|
m.SetLookup({5, 6, 7});
|
||||||
@ -818,14 +819,14 @@ TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
|
|||||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({2, 3, 4}));
|
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4}));
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
|
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableLookupOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32,
|
HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||||
3);
|
TensorType_FLOAT32, 3);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetLookup({5, 6, 7});
|
m.SetLookup({5, 6, 7});
|
||||||
@ -869,8 +870,8 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableImport) {
|
TEST(HashtableOpsTest, TestHashtableImport) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableImportOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32,
|
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||||
3);
|
TensorType_FLOAT32, 3);
|
||||||
EXPECT_EQ(m.GetResources().size(), 0);
|
EXPECT_EQ(m.GetResources().size(), 0);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetKeys({1, 2, 3});
|
m.SetKeys({1, 2, 3});
|
||||||
@ -890,8 +891,8 @@ TEST(HashtableOpsTest, TestHashtableImport) {
|
|||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
TEST(HashtableOpsTest, TestHashtableImportTwice) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableImportOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32,
|
HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
|
||||||
3);
|
TensorType_FLOAT32, 3);
|
||||||
EXPECT_EQ(m.GetResources().size(), 0);
|
EXPECT_EQ(m.GetResources().size(), 0);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
m.SetKeys({1, 2, 3});
|
m.SetKeys({1, 2, 3});
|
||||||
@ -929,7 +930,8 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
|
|||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableSize) {
|
TEST(HashtableOpsTest, TestHashtableSize) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableSizeOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32);
|
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||||
|
TensorType_INT32);
|
||||||
|
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
|
|
||||||
@ -937,13 +939,14 @@ TEST(HashtableOpsTest, TestHashtableSize) {
|
|||||||
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
|
|
||||||
EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({3}));
|
EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3}));
|
||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
|
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
|
||||||
const int kResourceId = 42;
|
const int kResourceId = 42;
|
||||||
HashtableSizeOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32);
|
HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
|
||||||
|
TensorType_INT32);
|
||||||
m.SetResourceId({kResourceId});
|
m.SetResourceId({kResourceId});
|
||||||
|
|
||||||
// Invoke without hash table initialization.
|
// Invoke without hash table initialization.
|
||||||
|
@ -15,11 +15,10 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
||||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace resource {
|
namespace resource {
|
||||||
|
|
||||||
@ -35,7 +34,8 @@ class ResourceBase {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// WARNING: Experimental interface, subject to change.
|
/// WARNING: Experimental interface, subject to change.
|
||||||
using ResourceMap = std::unordered_map<int32, std::unique_ptr<ResourceBase>>;
|
using ResourceMap =
|
||||||
|
std::unordered_map<std::int32_t, std::unique_ptr<ResourceBase>>;
|
||||||
|
|
||||||
} // namespace resource
|
} // namespace resource
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -85,7 +85,7 @@ LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
|
|||||||
TfLiteType value_type) {
|
TfLiteType value_type) {
|
||||||
switch (value_type) {
|
switch (value_type) {
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
return new StaticHashtable<KeyType, int32>(key_type, value_type);
|
return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
|
||||||
case kTfLiteString:
|
case kTfLiteString:
|
||||||
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
|
return new StaticHashtable<KeyType, std::string>(key_type, value_type);
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
@ -99,7 +99,8 @@ LookupInterface* CreateStaticHashtable(TfLiteType key_type,
|
|||||||
TfLiteType value_type) {
|
TfLiteType value_type) {
|
||||||
switch (key_type) {
|
switch (key_type) {
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
return CreateStaticHashtableWithGivenKey<int32>(key_type, value_type);
|
return CreateStaticHashtableWithGivenKey<std::int32_t>(key_type,
|
||||||
|
value_type);
|
||||||
case kTfLiteString:
|
case kTfLiteString:
|
||||||
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
|
return CreateStaticHashtableWithGivenKey<std::string>(key_type,
|
||||||
value_type);
|
value_type);
|
||||||
|
Loading…
Reference in New Issue
Block a user