Create the reverse mapping of TfLiteType -> cpp types.
Also put the boilerplate in a macro similar to tensorflow/core/framework/types.h PiperOrigin-RevId: 321836334 Change-Id: Iac20ee3742d63938ba5bf84134757eabdd5e61c2
This commit is contained in:
parent
0a7b7b014f
commit
c49b3f570d
@ -624,6 +624,16 @@ cc_library(
|
||||
deps = ["//tensorflow/lite/c:common"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "type_to_tflitetype_test",
|
||||
size = "small",
|
||||
srcs = ["type_to_tflitetype_test.cc"],
|
||||
deps = [
|
||||
":type_to_tflitetype",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "minimal_logging_test",
|
||||
size = "small",
|
||||
|
@ -28,59 +28,43 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Map statically from a c++ type to a TfLiteType. Used in interpreter for safe
|
||||
// casts.
|
||||
template <class T>
|
||||
// Map statically from a C++ type to a TfLiteType. Used in interpreter for
|
||||
// safe casts.
|
||||
// Example:
|
||||
// typeToTfLiteType<bool>() -> kTfLiteBool
|
||||
template <typename T>
|
||||
constexpr TfLiteType typeToTfLiteType() {
|
||||
return kTfLiteNoType;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<int>() {
|
||||
return kTfLiteInt32;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<int16_t>() {
|
||||
return kTfLiteInt16;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<int64_t>() {
|
||||
return kTfLiteInt64;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<float>() {
|
||||
return kTfLiteFloat32;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<unsigned char>() {
|
||||
return kTfLiteUInt8;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<int8_t>() {
|
||||
return kTfLiteInt8;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<bool>() {
|
||||
return kTfLiteBool;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
|
||||
return kTfLiteComplex64;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<std::complex<double>>() {
|
||||
return kTfLiteComplex128;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<std::string>() {
|
||||
return kTfLiteString;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() {
|
||||
return kTfLiteFloat16;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<double>() {
|
||||
return kTfLiteFloat64;
|
||||
}
|
||||
// Map from TfLiteType to the corresponding C++ type.
|
||||
// Example:
|
||||
// TfLiteTypeToType<kTfLiteBool>::Type -> bool
|
||||
template <TfLiteType TFLITE_TYPE_ENUM>
|
||||
struct TfLiteTypeToType {}; // Specializations below
|
||||
|
||||
// Template specialization for both typeToTfLiteType and TfLiteTypeToType.
|
||||
#define MATCH_TYPE_AND_TFLITE_TYPE(CPP_TYPE, TFLITE_TYPE_ENUM) \
|
||||
template <> \
|
||||
constexpr TfLiteType typeToTfLiteType<CPP_TYPE>() { \
|
||||
return TFLITE_TYPE_ENUM; \
|
||||
} \
|
||||
template <> \
|
||||
struct TfLiteTypeToType<TFLITE_TYPE_ENUM> { \
|
||||
using Type = CPP_TYPE; \
|
||||
}
|
||||
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(int, kTfLiteInt32);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<float>, kTfLiteComplex64);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<double>, kTfLiteComplex128);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(std::string, kTfLiteString);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16);
|
||||
MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64);
|
||||
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_
|
||||
|
65
tensorflow/lite/type_to_tflitetype_test.cc
Normal file
65
tensorflow/lite/type_to_tflitetype_test.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
TEST(TypeToTfLiteType, TypeMapsAreInverseOfEachOther) {
|
||||
EXPECT_EQ(kTfLiteInt16,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt16>::Type>());
|
||||
EXPECT_EQ(kTfLiteInt32,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt32>::Type>());
|
||||
EXPECT_EQ(kTfLiteFloat32,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat32>::Type>());
|
||||
EXPECT_EQ(kTfLiteUInt8,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteUInt8>::Type>());
|
||||
EXPECT_EQ(kTfLiteInt8,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt8>::Type>());
|
||||
EXPECT_EQ(kTfLiteBool,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteBool>::Type>());
|
||||
EXPECT_EQ(kTfLiteComplex64,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteComplex64>::Type>());
|
||||
EXPECT_EQ(kTfLiteComplex128,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteComplex128>::Type>());
|
||||
EXPECT_EQ(kTfLiteString,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteString>::Type>());
|
||||
EXPECT_EQ(kTfLiteFloat16,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat16>::Type>());
|
||||
EXPECT_EQ(kTfLiteFloat64,
|
||||
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat64>::Type>());
|
||||
}
|
||||
|
||||
TEST(TypeToTfLiteType, Sanity) {
|
||||
EXPECT_EQ(kTfLiteFloat32, typeToTfLiteType<float>());
|
||||
EXPECT_EQ(kTfLiteBool, typeToTfLiteType<bool>());
|
||||
EXPECT_EQ(kTfLiteString, typeToTfLiteType<std::string>());
|
||||
static_assert(
|
||||
std::is_same<float, TfLiteTypeToType<kTfLiteFloat32>::Type>::value,
|
||||
"TfLiteTypeToType test failure");
|
||||
static_assert(std::is_same<bool, TfLiteTypeToType<kTfLiteBool>::Type>::value,
|
||||
"TfLiteTypeToType test failure");
|
||||
static_assert(
|
||||
std::is_same<std::string, TfLiteTypeToType<kTfLiteString>::Type>::value,
|
||||
"TfLiteTypeToType test failure");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user