Use int32_t instead of int32 in resource code

PiperOrigin-RevId: 283781525
Change-Id: I2f3d51fc934ea98a3ef1c3504cb4b4b29a157d5a
This commit is contained in:
Jared Duke 2019-12-04 10:17:57 -08:00 committed by TensorFlower Gardener
parent de4e14925e
commit 9eea1b6de5
4 changed files with 22 additions and 17 deletions

View File

@ -86,7 +86,8 @@ TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* resource_handle_tensor = TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor); GetOutput(context, node, kResourceHandleTensor);
auto* resource_handle_data = GetTensorData<int32>(resource_handle_tensor); auto* resource_handle_data =
GetTensorData<std::int32_t>(resource_handle_tensor);
resource_handle_data[0] = resource_id; resource_handle_data[0] = resource_id;
Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_); Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);

View File

@ -808,7 +808,8 @@ class HashtableLookupOpModel : public BaseHashtableOpModel {
TEST(HashtableOpsTest, TestHashtableLookupIntToInt) { TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableLookupOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32, 3); HashtableLookupOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32, 3);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7}); m.SetLookup({5, 6, 7});
@ -818,14 +819,14 @@ TEST(HashtableOpsTest, TestHashtableLookupIntToInt) {
kTfLiteInt32, {4, 5, 6}, {1, 2, 3}); kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({2, 3, 4})); EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({2, 3, 4}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
} }
TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) { TEST(HashtableOpsTest, TestHashtableLookupIntToFloat) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableLookupOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32, HashtableLookupOpModel<std::int32_t, float> m(TensorType_INT32,
3); TensorType_FLOAT32, 3);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetLookup({5, 6, 7}); m.SetLookup({5, 6, 7});
@ -869,8 +870,8 @@ class HashtableImportOpModel : public BaseHashtableOpModel {
TEST(HashtableOpsTest, TestHashtableImport) { TEST(HashtableOpsTest, TestHashtableImport) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableImportOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32, HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
3); TensorType_FLOAT32, 3);
EXPECT_EQ(m.GetResources().size(), 0); EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3}); m.SetKeys({1, 2, 3});
@ -890,8 +891,8 @@ TEST(HashtableOpsTest, TestHashtableImport) {
TEST(HashtableOpsTest, TestHashtableImportTwice) { TEST(HashtableOpsTest, TestHashtableImportTwice) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableImportOpModel<int32, float> m(TensorType_INT32, TensorType_FLOAT32, HashtableImportOpModel<std::int32_t, float> m(TensorType_INT32,
3); TensorType_FLOAT32, 3);
EXPECT_EQ(m.GetResources().size(), 0); EXPECT_EQ(m.GetResources().size(), 0);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
m.SetKeys({1, 2, 3}); m.SetKeys({1, 2, 3});
@ -929,7 +930,8 @@ class HashtableSizeOpModel : public BaseHashtableOpModel {
TEST(HashtableOpsTest, TestHashtableSize) { TEST(HashtableOpsTest, TestHashtableSize) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableSizeOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32); HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
@ -937,13 +939,14 @@ TEST(HashtableOpsTest, TestHashtableSize) {
kTfLiteInt32, {4, 5, 6}, {1, 2, 3}); kTfLiteInt32, {4, 5, 6}, {1, 2, 3});
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({3})); EXPECT_THAT(m.GetOutput<std::int32_t>(), ElementsAreArray({3}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
} }
TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) { TEST(HashtableOpsTest, TestHashtableSizeNonInitialized) {
const int kResourceId = 42; const int kResourceId = 42;
HashtableSizeOpModel<int32, int32> m(TensorType_INT32, TensorType_INT32); HashtableSizeOpModel<std::int32_t, std::int32_t> m(TensorType_INT32,
TensorType_INT32);
m.SetResourceId({kResourceId}); m.SetResourceId({kResourceId});
// Invoke without hash table initialization. // Invoke without hash table initialization.

View File

@ -15,11 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_RESOURCE_RESOURCE_BASE_H_
#include <cstdint>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite { namespace tflite {
namespace resource { namespace resource {
@ -35,7 +34,8 @@ class ResourceBase {
}; };
/// WARNING: Experimental interface, subject to change. /// WARNING: Experimental interface, subject to change.
using ResourceMap = std::unordered_map<int32, std::unique_ptr<ResourceBase>>; using ResourceMap =
std::unordered_map<std::int32_t, std::unique_ptr<ResourceBase>>;
} // namespace resource } // namespace resource
} // namespace tflite } // namespace tflite

View File

@ -85,7 +85,7 @@ LookupInterface* CreateStaticHashtableWithGivenKey(TfLiteType key_type,
TfLiteType value_type) { TfLiteType value_type) {
switch (value_type) { switch (value_type) {
case kTfLiteInt32: case kTfLiteInt32:
return new StaticHashtable<KeyType, int32>(key_type, value_type); return new StaticHashtable<KeyType, std::int32_t>(key_type, value_type);
case kTfLiteString: case kTfLiteString:
return new StaticHashtable<KeyType, std::string>(key_type, value_type); return new StaticHashtable<KeyType, std::string>(key_type, value_type);
case kTfLiteFloat32: case kTfLiteFloat32:
@ -99,7 +99,8 @@ LookupInterface* CreateStaticHashtable(TfLiteType key_type,
TfLiteType value_type) { TfLiteType value_type) {
switch (key_type) { switch (key_type) {
case kTfLiteInt32: case kTfLiteInt32:
return CreateStaticHashtableWithGivenKey<int32>(key_type, value_type); return CreateStaticHashtableWithGivenKey<std::int32_t>(key_type,
value_type);
case kTfLiteString: case kTfLiteString:
return CreateStaticHashtableWithGivenKey<std::string>(key_type, return CreateStaticHashtableWithGivenKey<std::string>(key_type,
value_type); value_type);