Move GetInt8DataPtr function to kernel_util.h
PiperOrigin-RevId: 243924570
This commit is contained in:
parent
ffe054107f
commit
1ce55fb6ed
@ -281,6 +281,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":op_macros",
|
":op_macros",
|
||||||
"//tensorflow/lite/c:c_api_internal",
|
"//tensorflow/lite/c:c_api_internal",
|
||||||
|
"//tensorflow/lite/kernels:kernel_util",
|
||||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||||
],
|
],
|
||||||
|
@ -27,16 +27,6 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace rnn {
|
namespace rnn {
|
||||||
|
|
||||||
namespace {
|
|
||||||
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
|
|
||||||
if (is_uint8) {
|
|
||||||
return reinterpret_cast<int8_t*>(tensor->data.uint8);
|
|
||||||
} else {
|
|
||||||
return tensor->data.int8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
constexpr int kWeightsTensor = 1;
|
constexpr int kWeightsTensor = 1;
|
||||||
constexpr int kRecurrentWeightsTensor = 2;
|
constexpr int kRecurrentWeightsTensor = 2;
|
||||||
|
@ -31,18 +31,6 @@ namespace ops {
|
|||||||
namespace builtin {
|
namespace builtin {
|
||||||
namespace bidirectional_sequence_rnn {
|
namespace bidirectional_sequence_rnn {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
|
|
||||||
if (is_uint8) {
|
|
||||||
return reinterpret_cast<int8_t*>(tensor->data.uint8);
|
|
||||||
} else {
|
|
||||||
return tensor->data.int8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
|
@ -65,6 +65,14 @@ inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
|
||||||
|
if (is_uint8) {
|
||||||
|
return reinterpret_cast<int8_t*>(tensor->data.uint8);
|
||||||
|
} else {
|
||||||
|
return tensor->data.int8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Determines whether tensor is constant.
|
// Determines whether tensor is constant.
|
||||||
inline bool IsConstantTensor(const TfLiteTensor* tensor) {
|
inline bool IsConstantTensor(const TfLiteTensor* tensor) {
|
||||||
return tensor->allocation_type == kTfLiteMmapRo;
|
return tensor->allocation_type == kTfLiteMmapRo;
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
#include "tensorflow/lite/kernels/op_macros.h"
|
#include "tensorflow/lite/kernels/op_macros.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -857,14 +858,6 @@ inline void LstmStepWithAuxInput(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
|
|
||||||
if (is_uint8) {
|
|
||||||
return reinterpret_cast<int8_t*>(tensor->data.uint8);
|
|
||||||
} else {
|
|
||||||
return tensor->data.int8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TfLiteStatus EvalFloat(
|
TfLiteStatus EvalFloat(
|
||||||
|
Loading…
Reference in New Issue
Block a user