Avoid inclusion of C++ string header in Micro to help with platform porting

PiperOrigin-RevId: 326753889
Change-Id: I6b93e10b7151c3e44c4d6bf97911359a94a3e839
Pete Warden 2020-08-14 16:45:37 -07:00 committed by TensorFlower Gardener
parent 2d1e9501e3
commit 0c32f37be5
15 changed files with 330 additions and 217 deletions
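The recurring pattern across the files below: string-dependent declarations are fenced with the existing TF_LITE_STATIC_MEMORY macro, and headers are split so that Micro targets never transitively pull in <string>. A minimal, hypothetical sketch of that guard pattern (FirstString is an illustrative helper, not part of this commit):

#include "tensorflow/lite/c/common.h"

// On platforms built with -DTF_LITE_STATIC_MEMORY there is no dynamic
// allocation, so string-tensor helpers are compiled out entirely.
#ifndef TF_LITE_STATIC_MEMORY
#include <string>
#include "tensorflow/lite/string_util.h"

// Hypothetical helper: copies the first entry of a string tensor.
inline std::string FirstString(const TfLiteTensor* tensor) {
  const tflite::StringRef ref = tflite::GetString(tensor, 0);
  return std::string(ref.str, ref.len);
}
#endif  // TF_LITE_STATIC_MEMORY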

View File

@@ -55,6 +55,14 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "tf_lite_static_memory",
values = {
"copt": "-DTF_LITE_STATIC_MEMORY",
"cpu": "k8",
},
)
TFLITE_DEFAULT_COPTS = if_not_windows([
"-Wall",
"-Wno-comment",
@@ -616,7 +624,14 @@ cc_library(
cc_library(
name = "type_to_tflitetype",
hdrs = ["type_to_tflitetype.h"],
hdrs = [
"portable_type_to_tflitetype.h",
] + select({
":tf_lite_static_memory": [],
"//conditions:default": [
"type_to_tflitetype.h",
],
}),
deps = ["//tensorflow/lite/c:common"],
)
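With the new config_setting, a build that passes --copt=-DTF_LITE_STATIC_MEMORY (on the k8 CPU, matching the values above) exports only the portable header from this target, while default builds still ship the string-aware type_to_tflitetype.h as well. A hedged sketch of what client code includes under each configuration:

// Static-memory (Micro-friendly) builds can only reach the portable subset:
#include "tensorflow/lite/portable_type_to_tflitetype.h"

// Default builds may keep including the full header, which layers the
// std::string mapping on top of the portable definitions:
// #include "tensorflow/lite/type_to_tflitetype.h"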

View File

@@ -490,6 +490,7 @@ cc_library(
"reference/integer_ops/mean.h",
"reference/integer_ops/transpose_conv.h",
"reference/reference_ops.h",
"reference/string_comparisons.h",
"reference/sparse_ops/fully_connected.h",
],
}),
@@ -561,6 +562,7 @@ cc_library(
"reference/round.h",
"reference/softmax.h",
"reference/strided_slice.h",
"reference/string_comparisons.h",
"reference/sub.h",
"reference/tanh.h",
],
@@ -598,9 +600,14 @@ cc_library(
cc_library(
name = "tensor",
hdrs = [
"tensor.h",
"portable_tensor.h",
"tensor_ctypes.h",
],
] + select({
":tf_lite_static_memory": [],
"//conditions:default": [
"tensor.h",
],
}),
copts = tflite_copts(),
deps = [
":types",
@@ -613,9 +620,14 @@ cc_library(
cc_library(
name = "reference",
hdrs = [
"tensor.h",
"portable_tensor.h",
"tensor_ctypes.h",
],
] + select({
":tf_lite_static_memory": [],
"//conditions:default": [
"tensor.h",
],
}),
copts = tflite_copts(),
deps = [
":types",

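The tensor and reference targets above use the same select(): static-memory builds export only portable_tensor.h and tensor_ctypes.h. Kernels that must also compile for Micro therefore include the portable header directly, a hedged illustration:

// Portable kernels: everything from tensor.h except the std::string-based
// SequentialTensorWriter specialization.
#include "tensorflow/lite/kernels/internal/portable_tensor.h"

// Full-TFLite-only kernels may still use tensor.h, which re-exports the
// portable header and adds the string specialization on top.
// #include "tensorflow/lite/kernels/internal/tensor.h"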
View File

@@ -0,0 +1,123 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
#include <complex>
#include <cstring>  // memcpy, used by SequentialTensorWriter::WriteN
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
class VectorOfTensors {
public:
// Build with the tensors in 'tensor_list'.
VectorOfTensors(const TfLiteContext& context,
const TfLiteIntArray& tensor_list) {
int num_tensors = tensor_list.size;
all_data_.reserve(num_tensors);
all_shape_.reserve(num_tensors);
all_shape_ptr_.reserve(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
all_data_.push_back(GetTensorData<T>(t));
all_shape_.push_back(GetTensorShape(t));
}
// Taking a pointer into a std::vector is only OK if the vector is never
// modified afterwards, so we fully populate all_shape_ in the previous loop
// and only then take stable pointers to its elements here.
for (int i = 0; i < num_tensors; ++i) {
all_shape_ptr_.push_back(&all_shape_[i]);
}
}
// Return a pointer to the data pointers of all tensors in the list. For
// example:
// float* const* f = v.data();
// f[0][1] is the second element of the first tensor.
T* const* data() const { return all_data_.data(); }
// Return a pointer to the shape pointers of all tensors in the list. For
// example:
//   const RuntimeShape* const* d = v.shapes();
//   d[1] are the dimensions of the second tensor in the list.
const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
private:
std::vector<T*> all_data_;
std::vector<RuntimeShape> all_shape_;
std::vector<RuntimeShape*> all_shape_ptr_;
};
// A list of quantized tensors in a format that can be used by kernels like
// split and concatenation.
class VectorOfQuantizedTensors : public VectorOfTensors<uint8_t> {
public:
// Build with the tensors in 'tensor_list'.
VectorOfQuantizedTensors(const TfLiteContext& context,
const TfLiteIntArray& tensor_list)
: VectorOfTensors<uint8_t>(context, tensor_list) {
for (int i = 0; i < tensor_list.size; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
zero_point_.push_back(t->params.zero_point);
scale_.push_back(t->params.scale);
}
}
const float* scale() const { return scale_.data(); }
const int32_t* zero_point() const { return zero_point_.data(); }
private:
std::vector<int32_t> zero_point_;
std::vector<float> scale_;
};
// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
input_data_ = GetTensorData<T>(input);
output_ptr_ = GetTensorData<T>(output);
}
SequentialTensorWriter(const T* input_data, T* output_data)
: input_data_(input_data), output_ptr_(output_data) {}
void Write(int position) { *output_ptr_++ = input_data_[position]; }
void WriteN(int position, int len) {
memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
output_ptr_ += len;
}
private:
const T* input_data_;
T* output_ptr_;
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_PORTABLE_TENSOR_H_
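A hedged, self-contained sketch of SequentialTensorWriter from the header above, gathering randomly indexed input elements into a contiguous output the way gather/slice kernels do:

#include <cstdio>

#include "tensorflow/lite/kernels/internal/portable_tensor.h"

int main() {
  const float input[5] = {10.f, 20.f, 30.f, 40.f, 50.f};
  float output[3] = {};
  // The raw-pointer constructor avoids any TfLiteTensor plumbing.
  tflite::SequentialTensorWriter<float> writer(input, output);
  writer.Write(4);      // output[0] = input[4]
  writer.WriteN(1, 2);  // output[1], output[2] = input[1], input[2]
  std::printf("%g %g %g\n", output[0], output[1], output[2]);  // 50 20 30
  return 0;
}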

View File

@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
@@ -51,18 +50,6 @@ inline bool LessEqualFn(T lhs, T rhs) {
return lhs <= rhs;
}
inline bool StringRefEqualFn(const StringRef& lhs, const StringRef& rhs) {
if (lhs.len != rhs.len) return false;
for (int i = 0; i < lhs.len; ++i) {
if (lhs.str[i] != rhs.str[i]) return false;
}
return true;
}
inline bool StringRefNotEqualFn(const StringRef& lhs, const StringRef& rhs) {
return !StringRefEqualFn(lhs, rhs);
}
template <typename T>
using ComparisonFn = bool (*)(T, T);
@@ -78,22 +65,6 @@ inline void ComparisonImpl(
}
}
inline void ComparisonStringImpl(bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& input1_shape,
const TfLiteTensor* input1,
const RuntimeShape& input2_shape,
const TfLiteTensor* input2,
const RuntimeShape& output_shape,
bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const auto lhs = GetString(input1, i);
const auto rhs = GetString(input2, i);
output_data[i] = F(lhs, rhs);
}
}
template <ComparisonFn<float> F>
inline void Comparison(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,
@@ -180,31 +151,6 @@ inline void BroadcastComparison4DSlowImpl(
}
}
inline void BroadcastComparison4DSlowStringImpl(
bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
const RuntimeShape& unextended_output_shape, bool* output_data) {
const BroadcastComparison4DSlowCommon dims =
BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
unextended_input2_shape,
unextended_output_shape);
for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
const auto lhs =
GetString(input1, SubscriptToIndex(dims.desc1, b, y, x, c));
const auto rhs =
GetString(input2, SubscriptToIndex(dims.desc2, b, y, x, c));
output_data[Offset(dims.output_shape, b, y, x, c)] = F(lhs, rhs);
}
}
}
}
}
template <ComparisonFn<float> F>
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,

View File

@@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"

View File

@@ -0,0 +1,84 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRING_COMPARISONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRING_COMPARISONS_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace reference_ops {
inline bool StringRefEqualFn(const StringRef& lhs, const StringRef& rhs) {
if (lhs.len != rhs.len) return false;
for (int i = 0; i < lhs.len; ++i) {
if (lhs.str[i] != rhs.str[i]) return false;
}
return true;
}
inline bool StringRefNotEqualFn(const StringRef& lhs, const StringRef& rhs) {
return !StringRefEqualFn(lhs, rhs);
}
inline void ComparisonStringImpl(bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& input1_shape,
const TfLiteTensor* input1,
const RuntimeShape& input2_shape,
const TfLiteTensor* input2,
const RuntimeShape& output_shape,
bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const auto lhs = GetString(input1, i);
const auto rhs = GetString(input2, i);
output_data[i] = F(lhs, rhs);
}
}
inline void BroadcastComparison4DSlowStringImpl(
bool (*F)(const StringRef&, const StringRef&),
const RuntimeShape& unextended_input1_shape, const TfLiteTensor* input1,
const RuntimeShape& unextended_input2_shape, const TfLiteTensor* input2,
const RuntimeShape& unextended_output_shape, bool* output_data) {
const BroadcastComparison4DSlowCommon dims =
BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
unextended_input2_shape,
unextended_output_shape);
for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
const auto lhs =
GetString(input1, SubscriptToIndex(dims.desc1, b, y, x, c));
const auto rhs =
GetString(input2, SubscriptToIndex(dims.desc2, b, y, x, c));
output_data[Offset(dims.output_shape, b, y, x, c)] = F(lhs, rhs);
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRING_COMPARISONS_H_
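A hedged, self-contained check of the byte-wise comparison above, assuming StringRef's str/len fields from tensorflow/lite/string_util.h (which this header already includes):

#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"

int main() {
  const tflite::StringRef lhs{"cat", 3};
  const tflite::StringRef rhs{"car", 3};
  // Equal lengths but differing bytes: not equal.
  std::printf("%d\n", tflite::reference_ops::StringRefEqualFn(lhs, rhs));     // 0
  std::printf("%d\n", tflite::reference_ops::StringRefNotEqualFn(lhs, rhs));  // 1
  return 0;
}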

View File

@@ -15,112 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_
#include <complex>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
// Most functionality has been moved into a version of this file that doesn't
// rely on std::string, so that it can be used in TFL Micro.
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
class VectorOfTensors {
public:
// Build with the tensors in 'tensor_list'.
VectorOfTensors(const TfLiteContext& context,
const TfLiteIntArray& tensor_list) {
int num_tensors = tensor_list.size;
all_data_.reserve(num_tensors);
all_shape_.reserve(num_tensors);
all_shape_ptr_.reserve(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
all_data_.push_back(GetTensorData<T>(t));
all_shape_.push_back(GetTensorShape(t));
}
// Taking a pointer into a std::vector is only OK if the vector is never
// modified afterwards, so we fully populate all_shape_ in the previous loop
// and only then take stable pointers to its elements here.
for (int i = 0; i < num_tensors; ++i) {
all_shape_ptr_.push_back(&all_shape_[i]);
}
}
// Return a pointer to the data pointers of all tensors in the list. For
// example:
// float* const* f = v.data();
// f[0][1] is the second element of the first tensor.
T* const* data() const { return all_data_.data(); }
// Return a pointer to the shape pointers of all tensors in the list. For
// example:
//   const RuntimeShape* const* d = v.shapes();
//   d[1] are the dimensions of the second tensor in the list.
const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
private:
std::vector<T*> all_data_;
std::vector<RuntimeShape> all_shape_;
std::vector<RuntimeShape*> all_shape_ptr_;
};
// A list of quantized tensors in a format that can be used by kernels like
// split and concatenation.
class VectorOfQuantizedTensors : public VectorOfTensors<uint8_t> {
public:
// Build with the tensors in 'tensor_list'.
VectorOfQuantizedTensors(const TfLiteContext& context,
const TfLiteIntArray& tensor_list)
: VectorOfTensors<uint8_t>(context, tensor_list) {
for (int i = 0; i < tensor_list.size; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
zero_point_.push_back(t->params.zero_point);
scale_.push_back(t->params.scale);
}
}
const float* scale() const { return scale_.data(); }
const int32_t* zero_point() const { return zero_point_.data(); }
private:
std::vector<int32_t> zero_point_;
std::vector<float> scale_;
};
// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
public:
SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
input_data_ = GetTensorData<T>(input);
output_ptr_ = GetTensorData<T>(output);
}
SequentialTensorWriter(const T* input_data, T* output_data)
: input_data_(input_data), output_ptr_(output_data) {}
void Write(int position) { *output_ptr_++ = input_data_[position]; }
void WriteN(int position, int len) {
memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
output_ptr_ += len;
}
private:
const T* input_data_;
T* output_ptr_;
};
// String ops are not yet supported on platforms w/ static memory.
#ifndef TF_LITE_STATIC_MEMORY
template <>
class SequentialTensorWriter<string> {
public:
@@ -140,7 +41,6 @@ class SequentialTensorWriter<string> {
TfLiteTensor* output_;
DynamicBuffer buffer_;
};
#endif // TF_LITE_STATIC_MEMORY
} // namespace tflite

View File

@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

View File

@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

View File

@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

View File

@@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/type_to_tflitetype.h"
namespace tflite {

View File

@@ -200,17 +200,15 @@ tensorflow/lite/kernels/internal/reference/tanh.h \
tensorflow/lite/kernels/internal/cppmath.h \
tensorflow/lite/kernels/internal/max.h \
tensorflow/lite/kernels/internal/min.h \
tensorflow/lite/kernels/internal/portable_tensor.h \
tensorflow/lite/kernels/internal/strided_slice_logic.h \
tensorflow/lite/kernels/internal/tensor.h \
tensorflow/lite/kernels/internal/tensor_ctypes.h \
tensorflow/lite/kernels/internal/types.h \
tensorflow/lite/kernels/kernel_util.h \
tensorflow/lite/kernels/op_macros.h \
tensorflow/lite/kernels/padding.h \
tensorflow/lite/portable_type_to_tflitetype.h \
tensorflow/lite/schema/schema_generated.h \
tensorflow/lite/string_type.h \
tensorflow/lite/string_util.h \
tensorflow/lite/type_to_tflitetype.h \
tensorflow/lite/version.h
THIRD_PARTY_CC_HDRS := \

View File

@@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PORTABLE_TYPE_TO_TFLITETYPE_H_
#define TENSORFLOW_LITE_PORTABLE_TYPE_TO_TFLITETYPE_H_
// Most of the definitions have been moved to this subheader so that Micro
// can include it without relying on <string>, which isn't available on all
// platforms.
// The Arduino build defines abs as a macro here. That is invalid C++ and
// breaks libc++'s <complex> header, so we undefine it.
#ifdef abs
#undef abs
#endif
#include <complex>
#include "tensorflow/lite/c/common.h"
namespace tflite {
// Map statically from a C++ type to a TfLiteType. Used in interpreter for
// safe casts.
// Example:
// typeToTfLiteType<bool>() -> kTfLiteBool
template <typename T>
constexpr TfLiteType typeToTfLiteType() {
return kTfLiteNoType;
}
// Map from TfLiteType to the corresponding C++ type.
// Example:
// TfLiteTypeToType<kTfLiteBool>::Type -> bool
template <TfLiteType TFLITE_TYPE_ENUM>
struct TfLiteTypeToType {}; // Specializations below
// Template specialization for both typeToTfLiteType and TfLiteTypeToType.
#define MATCH_TYPE_AND_TFLITE_TYPE(CPP_TYPE, TFLITE_TYPE_ENUM) \
template <> \
constexpr TfLiteType typeToTfLiteType<CPP_TYPE>() { \
return TFLITE_TYPE_ENUM; \
} \
template <> \
struct TfLiteTypeToType<TFLITE_TYPE_ENUM> { \
using Type = CPP_TYPE; \
}
// No string mapping is included here, since the TF Lite packed representation
// doesn't correspond to a C++ type well.
MATCH_TYPE_AND_TFLITE_TYPE(int, kTfLiteInt32);
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);
MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8);
MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8);
MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<float>, kTfLiteComplex64);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<double>, kTfLiteComplex128);
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16);
MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64);
} // namespace tflite
#endif // TENSORFLOW_LITE_PORTABLE_TYPE_TO_TFLITETYPE_H_
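Both mapping directions above are compile-time constructs, so they can be verified with static_assert; a hedged sketch:

#include <cstdint>
#include <type_traits>

#include "tensorflow/lite/portable_type_to_tflitetype.h"

// C++ type -> TfLiteType enum (typeToTfLiteType is constexpr).
static_assert(tflite::typeToTfLiteType<int8_t>() == kTfLiteInt8, "");
// Unmapped types fall back to the primary template's kTfLiteNoType.
static_assert(tflite::typeToTfLiteType<char*>() == kTfLiteNoType, "");
// TfLiteType enum -> C++ type, via the specialized struct.
static_assert(
    std::is_same<tflite::TfLiteTypeToType<kTfLiteFloat32>::Type, float>::value,
    "");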

View File

@@ -76,9 +76,6 @@ class DynamicBuffer {
// The function allocates space for the buffer but does NOT take ownership.
int WriteToBuffer(char** buffer);
// String tensors are not generally supported on platforms w/ static memory.
// TODO(b/156130024): Remove this guard after removing header from TFLM deps.
#ifndef TF_LITE_STATIC_MEMORY
// Fill content into a string tensor, with the given new_shape. The new shape
// must match the number of strings in this object. Caller relinquishes
// ownership of new_shape. If 'new_shape' is nullptr, keep the tensor's
@@ -87,7 +84,6 @@
// Fill content into a string tensor. Set shape to {num_strings}.
void WriteToTensorAsVector(TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
private:
// Data buffer to store contents of strings, not including headers.

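With the guard removed, DynamicBuffer unconditionally exposes the tensor-writing path (TFLM simply no longer depends on this header). A hedged sketch of the usual packing flow, assuming the AddString(const char*, size_t) overload from string_util.h and a TfLiteTensor* obtained from an interpreter:

#include "tensorflow/lite/string_util.h"

void FillStringTensor(TfLiteTensor* tensor) {
  tflite::DynamicBuffer buffer;
  buffer.AddString("hello", 5);
  buffer.AddString("world", 5);
  // Packs both strings into TF Lite's string-tensor format and sets the
  // tensor's shape to {2}.
  buffer.WriteToTensorAsVector(tensor);
}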
View File

@@ -15,56 +15,20 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_
#define TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_
// The Arduino build defines abs as a macro here. That is invalid C++ and
// breaks libc++'s <complex> header, so we undefine it.
#ifdef abs
#undef abs
#endif
#include <complex>
#include <string>
#include "tensorflow/lite/c/common.h"
// Most of the definitions have been moved to this subheader so that Micro
// can include it without relying on <string>, which isn't available on all
// platforms.
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
// Map statically from a C++ type to a TfLiteType. Used in interpreter for
// safe casts.
// Example:
// typeToTfLiteType<bool>() -> kTfLiteBool
template <typename T>
constexpr TfLiteType typeToTfLiteType() {
return kTfLiteNoType;
}
// Map from TfLiteType to the corresponding C++ type.
// Example:
// TfLiteTypeToType<kTfLiteBool>::Type -> bool
template <TfLiteType TFLITE_TYPE_ENUM>
struct TfLiteTypeToType {}; // Specializations below
// Template specialization for both typeToTfLiteType and TfLiteTypeToType.
#define MATCH_TYPE_AND_TFLITE_TYPE(CPP_TYPE, TFLITE_TYPE_ENUM) \
template <> \
constexpr TfLiteType typeToTfLiteType<CPP_TYPE>() { \
return TFLITE_TYPE_ENUM; \
} \
template <> \
struct TfLiteTypeToType<TFLITE_TYPE_ENUM> { \
using Type = CPP_TYPE; \
}
MATCH_TYPE_AND_TFLITE_TYPE(int, kTfLiteInt32);
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);
MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8);
MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8);
MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<float>, kTfLiteComplex64);
MATCH_TYPE_AND_TFLITE_TYPE(std::complex<double>, kTfLiteComplex128);
// TODO(b/163167649): This string conversion means that only the first entry
// in a string tensor will be returned as a std::string, so it's deprecated.
MATCH_TYPE_AND_TFLITE_TYPE(std::string, kTfLiteString);
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16);
MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64);
} // namespace tflite
#endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_