Move GetInt8DataPtr function to kernel_util.h

PiperOrigin-RevId: 243924570
This commit is contained in:
A. Unique TensorFlower 2019-04-16 19:40:50 -07:00 committed by TensorFlower Gardener
parent ffe054107f
commit 1ce55fb6ed
5 changed files with 10 additions and 30 deletions

View File

@ -281,6 +281,7 @@ cc_library(
deps = [ deps = [
":op_macros", ":op_macros",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:tensor_utils",
], ],

View File

@ -27,16 +27,6 @@ namespace ops {
namespace builtin { namespace builtin {
namespace rnn { namespace rnn {
namespace {
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
if (is_uint8) {
return reinterpret_cast<int8_t*>(tensor->data.uint8);
} else {
return tensor->data.int8;
}
}
} // namespace
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1; constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2; constexpr int kRecurrentWeightsTensor = 2;

View File

@ -31,18 +31,6 @@ namespace ops {
namespace builtin { namespace builtin {
namespace bidirectional_sequence_rnn { namespace bidirectional_sequence_rnn {
namespace {
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
if (is_uint8) {
return reinterpret_cast<int8_t*>(tensor->data.uint8);
} else {
return tensor->data.int8;
}
}
} // namespace
// LINT.IfChange // LINT.IfChange
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;

View File

@ -65,6 +65,14 @@ inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
return nullptr; return nullptr;
} }
inline int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
if (is_uint8) {
return reinterpret_cast<int8_t*>(tensor->data.uint8);
} else {
return tensor->data.int8;
}
}
// Determines whether tensor is constant. // Determines whether tensor is constant.
inline bool IsConstantTensor(const TfLiteTensor* tensor) { inline bool IsConstantTensor(const TfLiteTensor* tensor) {
return tensor->allocation_type == kTfLiteMmapRo; return tensor->allocation_type == kTfLiteMmapRo;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
namespace tflite { namespace tflite {
@ -857,14 +858,6 @@ inline void LstmStepWithAuxInput(
} }
} }
int8_t* GetInt8DataPtr(const TfLiteTensor* tensor, const bool is_uint8) {
if (is_uint8) {
return reinterpret_cast<int8_t*>(tensor->data.uint8);
} else {
return tensor->data.int8;
}
}
} // namespace } // namespace
TfLiteStatus EvalFloat( TfLiteStatus EvalFloat(