Check shape of constant tensor for ADD

GPU only handles 1x1x...xn dimiensions tensors. Do not handle random
constants.

PiperOrigin-RevId: 313563512
Change-Id: Ifee00ccc2138b4aa1067d476f8f73e5c8cc1e19a
This commit is contained in:
Terry Heo 2020-05-28 04:32:13 -07:00 committed by TensorFlower Gardener
parent a2e1334b92
commit 3c9dfef469
3 changed files with 15 additions and 1 deletions

View File

@ -402,6 +402,13 @@ class AddOperationParser : public TFLiteOperationParser {
return absl::UnimplementedError("ADD requires two input tensors.");
}
// TODO(eignasheva): Add shapes check.
for (int i = 0; i < 2; i++) {
auto input = tflite::GetInput(context, tflite_node, i);
if (IsConstantTensor(input) && input->dims->size > 0) {
RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
}
}
TfLiteAddParams* tf_options = nullptr;
return RetrieveBuiltinData(tflite_node, &tf_options);
}

View File

@ -239,7 +239,7 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
return absl::OkStatus();
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions) {
if (dimensions->size <= 0) {
return absl::InvalidArgumentError("Dimension is empty.");
}
@ -249,6 +249,11 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
GetDimensionString(dimensions), " cannot be reduced to linear."));
}
}
return absl::OkStatus();
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
RETURN_IF_ERROR(CheckIfLinearConvertible(dimensions));
shape->v = dimensions->data[dimensions->size - 1];
return absl::OkStatus();
}

View File

@ -108,6 +108,8 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape);
absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions);
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape);
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape);