internal cleanup functions
PiperOrigin-RevId: 261820730
This commit is contained in:
parent
9239c61f20
commit
476fd9f8da
@ -31,8 +31,10 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kShapeTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteIntArray* output_shape) {
|
||||
TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*);
|
||||
|
||||
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteIntArray* output_shape = GetOutputShape(context, node);
|
||||
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
|
||||
scoped_output_shape(output_shape, TfLiteIntArrayFree);
|
||||
|
||||
@ -65,8 +67,8 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
|
||||
return context->ResizeTensor(context, output, scoped_output_shape.release());
|
||||
}
|
||||
|
||||
TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
|
||||
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
|
||||
@ -77,8 +79,8 @@ TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
|
||||
|
||||
// The function is returned above this line if the shape tensor is usable.
|
||||
@ -99,7 +101,7 @@ TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
|
||||
}
|
||||
|
||||
// Check if the shape tensor is valid. Shapes should be int32 vectors.
|
||||
bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
|
||||
inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
|
||||
return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
|
||||
}
|
||||
@ -124,8 +126,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (output->type != kTfLiteString) {
|
||||
if (NumInputs(node) == 1 ||
|
||||
IsConstantTensor(GetInput(context, node, kShapeTensor))) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, ResizeOutput(context, node, GetOutputShape(context, node)));
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
|
||||
} else {
|
||||
SetTensorToDynamic(output);
|
||||
}
|
||||
@ -141,8 +142,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// a string tensor, or its shape cannot be calculated during Prepare(). In
|
||||
// either case, we now have all the information to calculate its shape.
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, ResizeOutput(context, node, GetOutputShape(context, node)));
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
|
||||
}
|
||||
|
||||
// Note that string tensors are always "dynamic" in the sense that their size
|
||||
|
Loading…
Reference in New Issue
Block a user