Propagate half precision float through tflite

PiperOrigin-RevId: 247082214
A. Unique TensorFlower 2019-05-07 13:43:58 -07:00 committed by TensorFlower Gardener
parent 289a0af9b0
commit c58ddf2520
24 changed files with 74 additions and 12 deletions


@@ -252,6 +252,7 @@ cc_test(
         "//tensorflow/lite/kernels/internal:tensor_utils",
         "//tensorflow/lite/schema:schema_fbs",
         "//tensorflow/lite/testing:util",
+        "//third_party/eigen3",
         "@com_google_googletest//:gtest",
     ],
 )


@@ -172,6 +172,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
       return "COMPLEX64";
     case kTfLiteString:
       return "STRING";
+    case kTfLiteFloat16:
+      return "FLOAT16";
   }
   return "Unknown type";
 }


@@ -195,6 +195,11 @@ typedef struct {
   float re, im;  // real and imaginary parts, respectively.
 } TfLiteComplex64;
 
+// Half precision data type compatible with the C99 definition.
+typedef struct {
+  uint16_t data;
+} TfLiteFloat16;
+
 // Types supported by tensor
 typedef enum {
   kTfLiteNoType = 0,
@@ -207,6 +212,7 @@ typedef enum {
   kTfLiteInt16 = 7,
   kTfLiteComplex64 = 8,
   kTfLiteInt8 = 9,
+  kTfLiteFloat16 = 10,
 } TfLiteType;
 
 // Return the name of a given type, for error reporting purposes.
@@ -259,6 +265,8 @@ typedef union {
   int32_t* i32;
   int64_t* i64;
   float* f;
+  // Placeholder for 16b float type. Use uint16* in the pointer union for now.
+  TfLiteFloat16* f16;
   char* raw;
   const char* raw_const;
   uint8_t* uint8;
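
TfLiteFloat16 only carries the raw IEEE 754 binary16 bit pattern; the struct stores bits, not a computable value, so consumers typically widen to float32 before doing arithmetic. As a rough, standalone illustration (not part of the commit; the helper name is made up), expanding that bit pattern into a float32 looks like this:

#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: expand the IEEE 754 binary16 bit pattern stored in
// TfLiteFloat16::data into a float32 value.
static float HalfBitsToFloat(uint16_t h) {
  const bool negative = (h & 0x8000u) != 0;
  const uint32_t exponent = (h >> 10) & 0x1Fu;
  const uint32_t fraction = h & 0x3FFu;

  float magnitude;
  if (exponent == 0) {
    // Zero or subnormal: magnitude is fraction * 2^-24.
    magnitude = std::ldexp(static_cast<float>(fraction), -24);
  } else if (exponent == 0x1Fu) {
    // Infinity (fraction == 0) or NaN (fraction != 0).
    magnitude = (fraction == 0) ? std::numeric_limits<float>::infinity()
                                : std::numeric_limits<float>::quiet_NaN();
  } else {
    // Normal number: (1 + fraction/1024) * 2^(exponent - 15).
    magnitude = std::ldexp(1.0f + static_cast<float>(fraction) / 1024.0f,
                           static_cast<int>(exponent) - 15);
  }
  return negative ? -magnitude : magnitude;
}

// Example: 0x3C00 encodes 1.0 in binary16, so HalfBitsToFloat(0x3C00) == 1.0f.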


@@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) {
   };
   EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE");
   EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32");
+  EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16");
   EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
   EXPECT_EQ(type_name(kTfLiteInt32), "INT32");
   EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8");


@@ -61,9 +61,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
       *type = kTfLiteFloat32;
       break;
     case TensorType_FLOAT16:
-      error_reporter->Report("Unimplemented data type float16 in tensor\n",
-                             tensor_type);
-      return kTfLiteError;
+      *type = kTfLiteFloat16;
+      break;
     case TensorType_INT16:
       *type = kTfLiteInt16;
       break;


@@ -141,6 +141,13 @@ TEST_F(FlatbufferConversionsTest, TestConvertTensorType) {
   EXPECT_EQ(kTfLiteFloat32, type);
 }
 
+TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) {
+  TfLiteType type;
+  EXPECT_EQ(kTfLiteOk,
+            ConvertTensorType(TensorType_FLOAT16, &type, &mock_reporter_));
+  EXPECT_EQ(kTfLiteFloat16, type);
+}
+
 }  // namespace tflite
 
 int main(int argc, char** argv) {


@@ -469,6 +469,9 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
     case kTfLiteInt8:
       *bytes = sizeof(int8_t) * count;
       break;
+    case kTfLiteFloat16:
+      *bytes = sizeof(TfLiteFloat16) * count;
+      break;
     default:
       ReportError(
           "Only float32, int8, int16, int32, int64, uint8, bool, complex64 "


@@ -60,6 +60,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
       return TF_FLOAT;
     case kTfLiteFloat32:
       return TF_FLOAT;
+    case kTfLiteFloat16:
+      return TF_HALF;
     case kTfLiteInt16:
       return TF_INT16;
     case kTfLiteInt32:
@@ -83,6 +85,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
   switch (type) {
     case TF_FLOAT:
       return kTfLiteFloat32;
+    case TF_HALF:
+      return kTfLiteFloat16;
     case TF_INT16:
       return kTfLiteInt16;
     case TF_INT32:


@@ -101,9 +101,9 @@ TEST(UtilTest, CopyShapeAndType) {
   EXPECT_EQ(
       CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst),
-      kTfLiteError);
-  EXPECT_EQ(context.error,
-            "TF Lite does not support TensorFlow data type: half");
+      kTfLiteOk);
+  EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+  EXPECT_EQ(dst.type, kTfLiteFloat16);
 }
 
 TEST(UtilTest, TypeConversionsFromTFLite) {


@@ -29,6 +29,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) {
   /** 32-bit single precision floating point. */
   TFLTensorDataTypeFloat32,
 
+  /** 16-bit half precision floating point. */
+  TFLTensorDataTypeFloat16,
+
   /** 32-bit signed integer. */
   TFLTensorDataTypeInt32,


@@ -366,6 +366,8 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
   switch (cTensorType) {
     case kTfLiteFloat32:
       return TFLTensorDataTypeFloat32;
+    case kTfLiteFloat16:
+      return TFLTensorDataTypeFloat16;
     case kTfLiteInt32:
       return TFLTensorDataTypeInt32;
     case kTfLiteUInt8:


@@ -62,6 +62,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
       return TensorType_FLOAT32;  // TODO(aselle): Consider an error.
     case kTfLiteFloat32:
       return TensorType_FLOAT32;
+    case kTfLiteFloat16:
+      return TensorType_FLOAT16;
     case kTfLiteInt32:
       return TensorType_INT32;
     case kTfLiteUInt8:


@@ -30,6 +30,11 @@ limitations under the License.
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/util.h"
 
+// TODO(b/132087118): move static_assert to c_api_internal when compiled with
+// C++.
+static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t),
+              "Float 16 type must be 16 bits.");
+
 namespace tflite {
 
 namespace {


@@ -74,6 +74,10 @@ constexpr TfLiteType typeToTfLiteType<string>() {
   return kTfLiteString;
 }
 
+template <>
+constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() {
+  return kTfLiteFloat16;
+}
+
 // An interpreter for a graph of nodes that input and output from tensors.
 // Each node of the graph processes a set of input tensors and produces a
 // set of output Tensors. All inputs/output tensors are referenced by index.
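
Because typeToTfLiteType is constexpr, the new specialization can be checked at compile time. An illustrative assertion a caller including interpreter.h might write (not part of the commit):

#include "tensorflow/lite/interpreter.h"

// Illustrative compile-time checks of the C++ type -> TfLiteType mapping.
static_assert(tflite::typeToTfLiteType<TfLiteFloat16>() == kTfLiteFloat16,
              "TfLiteFloat16 must map to kTfLiteFloat16");
static_assert(tflite::typeToTfLiteType<float>() == kTfLiteFloat32,
              "float must map to kTfLiteFloat32");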


@@ -17,6 +17,7 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "third_party/eigen3/Eigen/Core"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -165,7 +166,7 @@ TEST(BasicInterpreter, CheckAllocate) {
   } cases[] = {
       {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
       {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
-      {kTfLiteInt16, sizeof(int16_t)},
+      {kTfLiteInt16, sizeof(int16_t)}, {kTfLiteFloat16, sizeof(TfLiteFloat16)},
   };
 
   for (auto test : cases) {
@@ -238,6 +239,8 @@ TEST(BasicInterpreter, CheckResize) {
   const uint8_t uint8s[] = {3, 4};
   const int64_t int64s[] = {6, -7};
   const int16_t int16s[] = {8, -9};
+  const Eigen::half float16s[] = {Eigen::half_impl::float_to_half_rtne(-3.f),
+                                  Eigen::half_impl::float_to_half_rtne(-4.f)};
 
   struct {
     TfLiteType type;
@@ -249,6 +252,8 @@ TEST(BasicInterpreter, CheckResize) {
       {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
       {kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
       {kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
+      {kTfLiteFloat16, sizeof(TfLiteFloat16),
+       reinterpret_cast<const char*>(float16s)},
   };
 
   for (auto test : cases) {
@@ -283,10 +288,8 @@ TEST(BasicInterpreter, CheckResize) {
 TEST(BasicInterpreter, CheckAlignment) {
   struct {
     TfLiteType type;
-  } cases[] = {
-      {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
-      {kTfLiteInt64},   {kTfLiteInt16},
-  };
+  } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
+               {kTfLiteInt64},   {kTfLiteInt16}, {kTfLiteFloat16}};
 
   for (auto test : cases) {
     Interpreter interpreter;
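
The test data leans on Eigen::half having the same 16-bit layout as TfLiteFloat16, so bytes produced by Eigen can be copied straight into a tensor buffer. A standalone sketch of that assumption (HalfBits is a hypothetical stand-in for TfLiteFloat16):

#include <cstdint>
#include <cstring>

#include "third_party/eigen3/Eigen/Core"

// Stand-in with the same layout as TfLiteFloat16.
struct HalfBits { uint16_t data; };

int main() {
  static_assert(sizeof(Eigen::half) == sizeof(HalfBits),
                "both types must be 16 bits wide");
  const Eigen::half h(-3.0f);            // rounds to nearest-even binary16
  HalfBits bits;
  std::memcpy(&bits, &h, sizeof(bits));  // reinterpret the raw bit pattern
  // -3.0 in binary16 is sign 1, exponent 16, fraction 0x200 -> 0xC200.
  return (bits.data == 0xC200) ? 0 : 1;
}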


@@ -66,6 +66,11 @@ inline const float* GetTensorData(const TfLiteTensor* tensor) {
   return tensor != nullptr ? tensor->data.f : nullptr;
 }
 
+template <>
+inline const TfLiteFloat16* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.f16 : nullptr;
+}
+
 template <>
 inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
   return tensor != nullptr ? tensor->data.uint8 : nullptr;
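
A kernel handed a float16 input would read the raw 16-bit values through the new accessor and widen them before computing. A rough sketch of that pattern, with a plain pointer standing in for GetTensorData<TfLiteFloat16>(input) and HalfBits for TfLiteFloat16:

#include <cstdint>
#include <cstring>

#include "third_party/eigen3/Eigen/Core"

// Stand-in with the same layout as TfLiteFloat16.
struct HalfBits { uint16_t data; };

// Widen a buffer of binary16 values to float32 before running the usual
// float kernels; `in` plays the role of GetTensorData<TfLiteFloat16>(input).
void UpcastToFloat32(const HalfBits* in, int count, float* out) {
  for (int i = 0; i < count; ++i) {
    Eigen::half h;
    std::memcpy(&h, &in[i], sizeof(h));  // copy the raw bit pattern
    out[i] = static_cast<float>(h);      // Eigen performs the widening
  }
}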


@@ -20,7 +20,6 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -568,6 +567,7 @@ class SingleOpTest : public ::testing::TestWithParam<string> {
 template <typename T>
 TensorType GetTensorType() {
   if (std::is_same<T, float>::value) return TensorType_FLOAT32;
+  if (std::is_same<T, TfLiteFloat16>::value) return TensorType_FLOAT16;
   if (std::is_same<T, int32_t>::value) return TensorType_INT32;
   if (std::is_same<T, uint8_t>::value) return TensorType_UINT8;
   if (std::is_same<T, string>::value) return TensorType_STRING;


@@ -56,6 +56,8 @@ const char* TensorTypeName(TfLiteType type) {
       return "kTfLiteInt16";
     case kTfLiteComplex64:
       return "kTfLiteComplex64";
+    case kTfLiteFloat16:
+      return "kTfLiteFloat16";
   }
   return "(invalid)";
 }


@@ -32,6 +32,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
   switch (tf_lite_type) {
     case kTfLiteFloat32:
       return NPY_FLOAT32;
+    case kTfLiteFloat16:
+      return NPY_FLOAT16;
     case kTfLiteInt32:
       return NPY_INT32;
     case kTfLiteInt16:


@@ -61,6 +61,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
       return TensorType_FLOAT32;  // TODO(b/129336260): No schema type for none.
     case kTfLiteFloat32:
       return TensorType_FLOAT32;
+    case kTfLiteFloat16:
+      return TensorType_FLOAT16;
     case kTfLiteInt32:
       return TensorType_INT32;
     case kTfLiteUInt8:


@@ -31,6 +31,7 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g
 # Map of tf.dtypes to TFLite types_flag_pb2.
 _MAP_TF_TO_TFLITE_TYPES = {
     dtypes.float32: _types_pb2.FLOAT,
+    dtypes.float16: _types_pb2.FLOAT16,
     dtypes.int32: _types_pb2.INT32,
     dtypes.int64: _types_pb2.INT64,
     dtypes.string: _types_pb2.STRING,


@@ -50,6 +50,8 @@ class UtilTest(test_util.TensorFlowTestCase):
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.complex64),
         _types_pb2.COMPLEX64)
+    self.assertEqual(
+        util.convert_dtype_to_tflite_type(dtypes.half), _types_pb2.FLOAT16)
     with self.assertRaises(ValueError):
       util.convert_dtype_to_tflite_type(dtypes.bool)


@@ -223,6 +223,7 @@ enum class ArrayDataType : uint8 {
   kUint64,  // 10
   kString,
   kComplex64,
+  kFloat16,
 };
 
 // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type


@@ -46,4 +46,7 @@ enum IODataType {
   // Int8, quantized based on QuantizationParameters in schema.
   INT8 = 9;
 
+  // Half precision float, not quantized.
+  FLOAT16 = 10;
 }