Create the reverse mapping of TfLiteType -> cpp types.

Also put the boilerplate in a macro similar to tensorflow/core/framework/types.h

PiperOrigin-RevId: 321836334
Change-Id: Iac20ee3742d63938ba5bf84134757eabdd5e61c2
This commit is contained in:
A. Unique TensorFlower 2020-07-17 12:52:14 -07:00 committed by TensorFlower Gardener
parent 0a7b7b014f
commit c49b3f570d
3 changed files with 110 additions and 51 deletions

View File

@ -624,6 +624,16 @@ cc_library(
deps = ["//tensorflow/lite/c:common"],
)
cc_test(
name = "type_to_tflitetype_test",
size = "small",
srcs = ["type_to_tflitetype_test.cc"],
deps = [
":type_to_tflitetype",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "minimal_logging_test",
size = "small",

View File

@ -28,59 +28,43 @@ limitations under the License.
namespace tflite {
// Map statically from a c++ type to a TfLiteType. Used in interpreter for safe
// casts.
template <class T>
// Map statically from a C++ type to a TfLiteType. Used in interpreter for
// safe casts.
// Example:
// typeToTfLiteType<bool>() -> kTfLiteBool
template <typename T>
constexpr TfLiteType typeToTfLiteType() {
return kTfLiteNoType;
}
template <>
constexpr TfLiteType typeToTfLiteType<int>() {
return kTfLiteInt32;
}
template <>
constexpr TfLiteType typeToTfLiteType<int16_t>() {
return kTfLiteInt16;
}
template <>
constexpr TfLiteType typeToTfLiteType<int64_t>() {
return kTfLiteInt64;
}
template <>
constexpr TfLiteType typeToTfLiteType<float>() {
return kTfLiteFloat32;
}
template <>
constexpr TfLiteType typeToTfLiteType<unsigned char>() {
return kTfLiteUInt8;
}
template <>
constexpr TfLiteType typeToTfLiteType<int8_t>() {
return kTfLiteInt8;
}
template <>
constexpr TfLiteType typeToTfLiteType<bool>() {
return kTfLiteBool;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
return kTfLiteComplex64;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::complex<double>>() {
return kTfLiteComplex128;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::string>() {
return kTfLiteString;
}
template <>
constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() {
return kTfLiteFloat16;
}
template <>
constexpr TfLiteType typeToTfLiteType<double>() {
return kTfLiteFloat64;
}
// Map from TfLiteType to the corresponding C++ type.
// Example:
// TfLiteTypeToType<kTfLiteBool>::Type -> bool
template <TfLiteType TFLITE_TYPE_ENUM>
struct TfLiteTypeToType {}; // Specializations below
// Template specialization for both typeToTfLiteType and TfLiteTypeToType.
#define MATCH_TYPE_AND_TFLITE_TYPE(CPP_TYPE, TFLITE_TYPE_ENUM) \
template <> \
constexpr TfLiteType typeToTfLiteType<CPP_TYPE>() { \
return TFLITE_TYPE_ENUM; \
} \
template <> \
struct TfLiteTypeToType<TFLITE_TYPE_ENUM> { \
using Type = CPP_TYPE; \
}
MATCH_TYPE_AND_TFLITE_TYPE(int, kTfLiteInt32);
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);
MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8);
MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8);
MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<float>, kTfLiteComplex64);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<double>, kTfLiteComplex128);
MATCH_TYPE_AND_TFLITE_TYPE(std::string, kTfLiteString);
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16);
MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64);
} // namespace tflite
#endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_

View File

@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/type_to_tflitetype.h"
#include <string>
#include <gtest/gtest.h>
namespace tflite {
namespace {
TEST(TypeToTfLiteType, TypeMapsAreInverseOfEachOther) {
EXPECT_EQ(kTfLiteInt16,
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt16>::Type>());
EXPECT_EQ(kTfLiteInt32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt32>::Type>());
EXPECT_EQ(kTfLiteFloat32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat32>::Type>());
EXPECT_EQ(kTfLiteUInt8,
typeToTfLiteType<TfLiteTypeToType<kTfLiteUInt8>::Type>());
EXPECT_EQ(kTfLiteInt8,
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt8>::Type>());
EXPECT_EQ(kTfLiteBool,
typeToTfLiteType<TfLiteTypeToType<kTfLiteBool>::Type>());
EXPECT_EQ(kTfLiteComplex64,
typeToTfLiteType<TfLiteTypeToType<kTfLiteComplex64>::Type>());
EXPECT_EQ(kTfLiteComplex128,
typeToTfLiteType<TfLiteTypeToType<kTfLiteComplex128>::Type>());
EXPECT_EQ(kTfLiteString,
typeToTfLiteType<TfLiteTypeToType<kTfLiteString>::Type>());
EXPECT_EQ(kTfLiteFloat16,
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat16>::Type>());
EXPECT_EQ(kTfLiteFloat64,
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat64>::Type>());
}
TEST(TypeToTfLiteType, Sanity) {
EXPECT_EQ(kTfLiteFloat32, typeToTfLiteType<float>());
EXPECT_EQ(kTfLiteBool, typeToTfLiteType<bool>());
EXPECT_EQ(kTfLiteString, typeToTfLiteType<std::string>());
static_assert(
std::is_same<float, TfLiteTypeToType<kTfLiteFloat32>::Type>::value,
"TfLiteTypeToType test failure");
static_assert(std::is_same<bool, TfLiteTypeToType<kTfLiteBool>::Type>::value,
"TfLiteTypeToType test failure");
static_assert(
std::is_same<std::string, TfLiteTypeToType<kTfLiteString>::Type>::value,
"TfLiteTypeToType test failure");
}
} // namespace
} // namespace tflite