internal cleanup functions
PiperOrigin-RevId: 261820730
parent 9239c61f20
commit 476fd9f8da
@@ -31,8 +31,10 @@ constexpr int kInputTensor = 0;
 constexpr int kShapeTensor = 1;
 constexpr int kOutputTensor = 0;
 
-TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
-                          TfLiteIntArray* output_shape) {
+TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*);
+
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteIntArray* output_shape = GetOutputShape(context, node);
   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
       scoped_output_shape(output_shape, TfLiteIntArrayFree);
 
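Note: the hunk above keeps the scoped_output_shape idiom, where a std::unique_ptr with a free-function deleter owns the freshly created TfLiteIntArray until ownership is handed to context->ResizeTensor via release(). Below is a minimal standalone sketch of that idiom; IntArray, IntArrayCreate, IntArrayFree, and Consume are stand-ins invented here so the sketch compiles without the TFLite headers.

// Minimal standalone sketch of the scoped_output_shape idiom (assumed
// stand-ins: IntArray ~ TfLiteIntArray, IntArrayFree ~ TfLiteIntArrayFree,
// Consume ~ context->ResizeTensor, which takes ownership of the array).
#include <cstdlib>
#include <memory>

struct IntArray {
  int size;
  int data[8];
};

IntArray* IntArrayCreate(int size) {
  IntArray* a = static_cast<IntArray*>(std::calloc(1, sizeof(IntArray)));
  a->size = size;
  return a;
}

void IntArrayFree(IntArray* a) { std::free(a); }

// Stand-in for the call that takes ownership of the array on success.
bool Consume(IntArray* a) {
  IntArrayFree(a);
  return true;
}

int main() {
  IntArray* output_shape = IntArrayCreate(4);
  // The unique_ptr frees the array on every early-exit path below.
  std::unique_ptr<IntArray, void (*)(IntArray*)> scoped_output_shape(
      output_shape, IntArrayFree);
  if (output_shape->size > 8) return 1;  // error path: the deleter runs
  // Success path: release() hands ownership to the consumer.
  return Consume(scoped_output_shape.release()) ? 0 : 1;
}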
@@ -65,8 +67,8 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
   return context->ResizeTensor(context, output, scoped_output_shape.release());
 }
 
-TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
-                                         TfLiteNode* node) {
+inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
+                                                TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
 
   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
@@ -77,8 +79,8 @@ TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
   return output_shape;
 }
 
-TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
-                                        TfLiteNode* node) {
+inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
+                                               TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
 
   // The function is returned above this line if the shape tensor is usable.
@@ -99,7 +101,7 @@ TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
 }
 
 // Check if the shape tensor is valid. Shapes should be int32 vectors.
-bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
+inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
   return (shape->dims->size == 1 && shape->type == kTfLiteInt32);
 }
@@ -124,8 +126,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   if (output->type != kTfLiteString) {
     if (NumInputs(node) == 1 ||
         IsConstantTensor(GetInput(context, node, kShapeTensor))) {
-      TF_LITE_ENSURE_OK(
-          context, ResizeOutput(context, node, GetOutputShape(context, node)));
+      TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
     } else {
       SetTensorToDynamic(output);
     }
@@ -141,8 +142,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // a string tensor, or its shape cannot be calculated during Prepare(). In
   // either case, we now have all the information to calculate its shape.
   if (IsDynamicTensor(output)) {
-    TF_LITE_ENSURE_OK(
-        context, ResizeOutput(context, node, GetOutputShape(context, node)));
+    TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
   }
 
   // Note that string tensors are always "dynamic" in the sense that their size
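The diff only forward-declares GetOutputShape(TfLiteContext*, TfLiteNode*); its definition is not part of these hunks. A plausible reading, given GetOutputShapeFromTensor, GetOutputShapeFromParam, and ShapeIsVector above, is that it dispatches between the shape tensor and the shape carried in the op's params. The sketch below is an assumption, not the actual kernel code, and uses hypothetical stand-in types (Context, Node, IntArray) so it builds on its own.

// Hypothetical sketch of how the forward-declared GetOutputShape() might
// dispatch between the two helpers shown in the diff. Context, Node, and
// IntArray are invented stand-ins for TfLiteContext, TfLiteNode, and
// TfLiteIntArray; the real definition is not shown in the hunks above.
#include <cstdio>

struct Context {};
struct Node {
  bool has_shape_tensor;       // a second (shape) input is present
  bool shape_is_int32_vector;  // and it is a 1-D int32 tensor
};
struct IntArray { int size; };

IntArray* GetOutputShapeFromTensor(Context*, Node*) { return new IntArray{4}; }
IntArray* GetOutputShapeFromParam(Context*, Node*) { return new IntArray{2}; }
bool ShapeIsVector(Context*, Node* node) { return node->shape_is_int32_vector; }

// Assumed dispatch: prefer a usable shape tensor, otherwise fall back to the
// shape baked into the op's builtin params.
IntArray* GetOutputShape(Context* context, Node* node) {
  if (node->has_shape_tensor && ShapeIsVector(context, node)) {
    return GetOutputShapeFromTensor(context, node);
  }
  return GetOutputShapeFromParam(context, node);
}

int main() {
  Context ctx;
  Node node{true, true};
  IntArray* shape = GetOutputShape(&ctx, &node);
  std::printf("output rank = %d\n", shape->size);
  delete shape;
  return 0;
}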