From 3c9dfef4696c821de6e02f3c676c0c356602894b Mon Sep 17 00:00:00 2001
From: Terry Heo <terryheo@google.com>
Date: Thu, 28 May 2020 04:32:13 -0700
Subject: [PATCH] Check shape of constant tensor for ADD

GPU only handles 1x1x...xn dimiensions tensors. Do not handle random
constants.

PiperOrigin-RevId: 313563512
Change-Id: Ifee00ccc2138b4aa1067d476f8f73e5c8cc1e19a
---
 tensorflow/lite/delegates/gpu/common/model_builder.cc      | 7 +++++++
 .../lite/delegates/gpu/common/model_builder_helper.cc      | 7 ++++++-
 .../lite/delegates/gpu/common/model_builder_helper.h       | 2 ++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 061c65095eb..29d9813379e 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -402,6 +402,13 @@ class AddOperationParser : public TFLiteOperationParser {
       return absl::UnimplementedError("ADD requires two input tensors.");
     }
     // TODO(eignasheva): Add shapes check.
+    for (int i = 0; i < 2; i++) {
+      auto input = tflite::GetInput(context, tflite_node, i);
+      if (IsConstantTensor(input) && input->dims->size > 0) {
+        RETURN_IF_ERROR(CheckIfLinearConvertible(input->dims));
+      }
+    }
+
     TfLiteAddParams* tf_options = nullptr;
     return RetrieveBuiltinData(tflite_node, &tf_options);
   }
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
index 9a15f940fbd..a1705e6cf78 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc
@@ -239,7 +239,7 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
   return absl::OkStatus();
 }
 
-absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
+absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions) {
   if (dimensions->size <= 0) {
     return absl::InvalidArgumentError("Dimension is empty.");
   }
@@ -249,6 +249,11 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
           GetDimensionString(dimensions), "  cannot be reduced to linear."));
     }
   }
+  return absl::OkStatus();
+}
+
+absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
+  RETURN_IF_ERROR(CheckIfLinearConvertible(dimensions));
   shape->v = dimensions->data[dimensions->size - 1];
   return absl::OkStatus();
 }
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
index 9caa5630037..6cbfcd9e7d6 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
+++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h
@@ -108,6 +108,8 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
 
 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape);
 
+absl::Status CheckIfLinearConvertible(const TfLiteIntArray* dimensions);
+
 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape);
 
 absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape);