diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 7b34fa120f0..0e89da4c3b8 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -252,6 +252,7 @@ cc_test( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/testing:util", + "//third_party/eigen3", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/c/c_api_internal.c b/tensorflow/lite/c/c_api_internal.c index f20ee23bd81..926d992011f 100644 --- a/tensorflow/lite/c/c_api_internal.c +++ b/tensorflow/lite/c/c_api_internal.c @@ -172,6 +172,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "COMPLEX64"; case kTfLiteString: return "STRING"; + case kTfLiteFloat16: + return "FLOAT16"; } return "Unknown type"; } diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index d9f08be0faa..1948e1ba106 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -195,6 +195,11 @@ typedef struct { float re, im; // real and imaginary parts, respectively. } TfLiteComplex64; +// Half precision data type compatible with the C99 definition. +typedef struct { + uint16_t data; +} TfLiteFloat16; + // Types supported by tensor typedef enum { kTfLiteNoType = 0, @@ -207,6 +212,7 @@ typedef enum { kTfLiteInt16 = 7, kTfLiteComplex64 = 8, kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, } TfLiteType; // Return the name of a given type, for error reporting purposes. @@ -259,6 +265,8 @@ typedef union { int32_t* i32; int64_t* i64; float* f; + // Placeholder for 16b float type. Uses TfLiteFloat16* (a struct wrapping a uint16_t) in the pointer union for now. 
+ TfLiteFloat16* f16; char* raw; const char* raw_const; uint8_t* uint8; diff --git a/tensorflow/lite/c/c_api_internal_test.cc b/tensorflow/lite/c/c_api_internal_test.cc index d01cf63a3e0..9a37cd9552f 100644 --- a/tensorflow/lite/c/c_api_internal_test.cc +++ b/tensorflow/lite/c/c_api_internal_test.cc @@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) { }; EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE"); EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32"); + EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16"); EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); EXPECT_EQ(type_name(kTfLiteInt32), "INT32"); EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8"); diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 2354f000a71..9d496f676f3 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -61,9 +61,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, *type = kTfLiteFloat32; break; case TensorType_FLOAT16: - error_reporter->Report("Unimplemented data type float16 in tensor\n", - tensor_type); - return kTfLiteError; + *type = kTfLiteFloat16; + break; case TensorType_INT16: *type = kTfLiteInt16; break; diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 4a5de48302c..c7f8c1ad66e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -141,6 +141,13 @@ TEST_F(FlatbufferConversionsTest, TestConvertTensorType) { EXPECT_EQ(kTfLiteFloat32, type); } +TEST_F(FlatbufferConversionsTest, TestConvertTensorTypeFloat16) { + TfLiteType type; + EXPECT_EQ(kTfLiteOk, + ConvertTensorType(TensorType_FLOAT16, &type, &mock_reporter_)); + EXPECT_EQ(kTfLiteFloat16, type); +} + } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/lite/core/subgraph.cc 
b/tensorflow/lite/core/subgraph.cc index 082f57b808b..afa2d63f64f 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -469,6 +469,9 @@ TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims, case kTfLiteInt8: *bytes = sizeof(int8_t) * count; break; + case kTfLiteFloat16: + *bytes = sizeof(TfLiteFloat16) * count; + break; default: ReportError( "Only float32, int8, int16, int32, int64, uint8, bool, complex64 " diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index c995b360f9d..4279f4ae397 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -60,6 +60,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_FLOAT; case kTfLiteFloat32: return TF_FLOAT; + case kTfLiteFloat16: + return TF_HALF; case kTfLiteInt16: return TF_INT16; case kTfLiteInt32: @@ -83,6 +85,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { switch (type) { case TF_FLOAT: return kTfLiteFloat32; + case TF_HALF: + return kTfLiteFloat16; case TF_INT16: return kTfLiteInt16; case TF_INT32: diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc index 87104751b81..69bba405055 100644 --- a/tensorflow/lite/delegates/flex/util_test.cc +++ b/tensorflow/lite/delegates/flex/util_test.cc @@ -101,9 +101,9 @@ TEST(UtilTest, CopyShapeAndType) { EXPECT_EQ( CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst), - kTfLiteError); - EXPECT_EQ(context.error, - "TF Lite does not support TensorFlow data type: half"); + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteFloat16); } TEST(UtilTest, TypeConversionsFromTFLite) { diff --git a/tensorflow/lite/experimental/objc/apis/TFLTensor.h b/tensorflow/lite/experimental/objc/apis/TFLTensor.h index dc710abf4e2..fd781bd5723 100644 --- a/tensorflow/lite/experimental/objc/apis/TFLTensor.h +++ 
b/tensorflow/lite/experimental/objc/apis/TFLTensor.h @@ -29,6 +29,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) { /** 32-bit single precision floating point. */ TFLTensorDataTypeFloat32, + /** 16-bit half precision floating point. */ + TFLTensorDataTypeFloat16, + /** 32-bit signed integer. */ TFLTensorDataTypeInt32, diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm index cf5a6b4c92b..1c8b7f976ec 100644 --- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm @@ -366,6 +366,8 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ switch (cTensorType) { case kTfLiteFloat32: return TFLTensorDataTypeFloat32; + case kTfLiteFloat16: + return TFLTensorDataTypeFloat16; case kTfLiteInt32: return TFLTensorDataTypeInt32; case kTfLiteUInt8: diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/experimental/writer/enum_mapping.h index 4556f7463f7..77f7b26cbc2 100644 --- a/tensorflow/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/experimental/writer/enum_mapping.h @@ -62,6 +62,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; // TODO(aselle): Consider an error. case kTfLiteFloat32: return TensorType_FLOAT32; + case kTfLiteFloat16: + return TensorType_FLOAT16; case kTfLiteInt32: return TensorType_INT32; case kTfLiteUInt8: diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index c54bcc8166f..fd6e8ddd404 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -30,6 +30,11 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/util.h" +// TODO(b/132087118): move static_assert to c_api_internal when compiled with +// C++. 
+static_assert(sizeof(TfLiteFloat16) == sizeof(uint16_t), + "Float 16 type must be 16 bits."); + namespace tflite { namespace { diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index fd4dbfa6614..bee55a8c461 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -74,6 +74,10 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteString; } +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteFloat16; +} // An interpreter for a graph of nodes that input and output from tensors. // Each node of the graph processes a set of input tensors and produces a // set of output Tensors. All inputs/output tensors are referenced by index. diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index 78c3d4ddc7f..0c0c32b4eed 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -165,7 +166,7 @@ TEST(BasicInterpreter, CheckAllocate) { } cases[] = { {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, - {kTfLiteInt16, sizeof(int16_t)}, + {kTfLiteInt16, sizeof(int16_t)}, {kTfLiteFloat16, sizeof(TfLiteFloat16)}, }; for (auto test : cases) { @@ -238,6 +239,8 @@ TEST(BasicInterpreter, CheckResize) { const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; const int16_t int16s[] = {8, -9}; + const Eigen::half float16s[] = {Eigen::half_impl::float_to_half_rtne(-3.f), + Eigen::half_impl::float_to_half_rtne(-4.f)}; struct { TfLiteType type; @@ -249,6 +252,8 @@ TEST(BasicInterpreter, CheckResize) { {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), 
reinterpret_cast(int64s)}, {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, + {kTfLiteFloat16, sizeof(TfLiteFloat16), + reinterpret_cast(float16s)}, }; for (auto test : cases) { @@ -283,10 +288,8 @@ TEST(BasicInterpreter, CheckResize) { TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; - } cases[] = { - {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, - {kTfLiteInt64}, {kTfLiteInt16}, - }; + } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, + {kTfLiteInt64}, {kTfLiteInt16}, {kTfLiteFloat16}}; for (auto test : cases) { Interpreter interpreter; diff --git a/tensorflow/lite/kernels/internal/tensor_ctypes.h b/tensorflow/lite/kernels/internal/tensor_ctypes.h index f77fae251d8..8ee95d4d5b3 100644 --- a/tensorflow/lite/kernels/internal/tensor_ctypes.h +++ b/tensorflow/lite/kernels/internal/tensor_ctypes.h @@ -66,6 +66,11 @@ inline const float* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.f : nullptr; } +template <> +inline const TfLiteFloat16* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f16 : nullptr; +} + template <> inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index dfc5783422a..44f8aa317e2 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -20,7 +20,6 @@ limitations under the License. 
#include #include - #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -568,6 +567,7 @@ class SingleOpTest : public ::testing::TestWithParam { template TensorType GetTensorType() { if (std::is_same::value) return TensorType_FLOAT32; + if (std::is_same::value) return TensorType_FLOAT16; if (std::is_same::value) return TensorType_INT32; if (std::is_same::value) return TensorType_UINT8; if (std::is_same::value) return TensorType_STRING; diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 1113bf01b17..a59af3d680c 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -56,6 +56,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteInt16"; case kTfLiteComplex64: return "kTfLiteComplex64"; + case kTfLiteFloat16: + return "kTfLiteFloat16"; } return "(invalid)"; } diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc index 22ec88bafd5..110c3ac4e04 100644 --- a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc @@ -32,6 +32,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { switch (tf_lite_type) { case kTfLiteFloat32: return NPY_FLOAT32; + case kTfLiteFloat16: + return NPY_FLOAT16; case kTfLiteInt32: return NPY_INT32; case kTfLiteInt16: diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index 285935dc9df..8ea376c835a 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -61,6 +61,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT32; // TODO(b/129336260): No schema type for none. 
case kTfLiteFloat32: return TensorType_FLOAT32; + case kTfLiteFloat16: + return TensorType_FLOAT16; case kTfLiteInt32: return TensorType_INT32; case kTfLiteUInt8: diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 0331aa70208..3a0352f331c 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -31,6 +31,7 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g # Map of tf.dtypes to TFLite types_flag_pb2. _MAP_TF_TO_TFLITE_TYPES = { dtypes.float32: _types_pb2.FLOAT, + dtypes.float16: _types_pb2.FLOAT16, dtypes.int32: _types_pb2.INT32, dtypes.int64: _types_pb2.INT64, dtypes.string: _types_pb2.STRING, diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index cfb5ed365f6..65b53bc8afe 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -50,6 +50,8 @@ class UtilTest(test_util.TensorFlowTestCase): self.assertEqual( util.convert_dtype_to_tflite_type(dtypes.complex64), _types_pb2.COMPLEX64) + self.assertEqual( + util.convert_dtype_to_tflite_type(dtypes.half), _types_pb2.FLOAT16) with self.assertRaises(ValueError): util.convert_dtype_to_tflite_type(dtypes.bool) diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index fcee42c2294..67510c2b3b1 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -223,6 +223,7 @@ enum class ArrayDataType : uint8 { kUint64, // 10 kString, kComplex64, + kFloat16, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type diff --git a/tensorflow/lite/toco/types.proto b/tensorflow/lite/toco/types.proto index fa911b8a4c8..2c655517431 100644 --- a/tensorflow/lite/toco/types.proto +++ b/tensorflow/lite/toco/types.proto @@ -46,4 +46,7 @@ enum IODataType { // Int8, quantized based on QuantizationParameters in schema. INT8 = 9; + + // Half precision float, not quantized. + FLOAT16 = 10; }