Check shape of constant tensor for ADD
GPU only handles 1x1x...xn dimiensions tensors. Do not handle random constants. PiperOrigin-RevId: 313563512 Change-Id: Ifee00ccc2138b4aa1067d476f8f73e5c8cc1e19a
This commit is contained in:
parent
a2e1334b92
commit
3c9dfef469
tensorflow/lite/delegates/gpu/common
@ -402,6 +402,13 @@ class AddOperationParser : public TFLiteOperationParser {
|
||||
return absl::UnimplementedError("ADD requires two input tensors.");
|
||||
}
|
||||
// TODO(eignasheva): Add shapes check.
|
||||
for (int i = 0; i < 2; i++) {
|
||||
auto input = tflite::GetInput(context, tflite_node, i);
|
||||
if (IsConstantTensor(input) && input->dims->size > 0) {
|
||||
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteAddParams* tf_options = nullptr;
|
||||
return RetrieveBuiltinData(tflite_node, &tf_options);
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
||||
absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions) {
|
||||
if (dimensions->size <= 0) {
|
||||
return absl::InvalidArgumentError("Dimension is empty.");
|
||||
}
|
||||
@ -249,6 +249,11 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
||||
GetDimensionString(dimensions), " cannot be reduced to linear."));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
||||
RETURN_IF_ERROR(CheckIfLinearConvertible(dimensions));
|
||||
shape->v = dimensions->data[dimensions->size - 1];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -108,6 +108,8 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape);
|
||||
|
||||
absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions);
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape);
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape);
|
||||
|
Loading…
Reference in New Issue
Block a user