Teach TFLite model verifier about all supported types
PiperOrigin-RevId: 259073172
This commit is contained in:
		
							parent
							
								
									fdc106e412
								
							
						
					
					
						commit
						8f55026cc8
					
				| @ -91,7 +91,8 @@ cc_test( | ||||
|         "//tensorflow/core:framework_lite", | ||||
|         "//tensorflow/lite:framework", | ||||
|         "//tensorflow/lite:schema_fbs_version", | ||||
|         "//tensorflow/lite/c:c_api_internal", | ||||
|         "//tensorflow/lite:util", | ||||
|         "//tensorflow/lite/core/api", | ||||
|         "//tensorflow/lite/schema:schema_fbs", | ||||
|         "//tensorflow/lite/testing:util", | ||||
|         "@com_google_googletest//:gtest", | ||||
|  | ||||
| @ -130,20 +130,30 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, | ||||
|     case TensorType_FLOAT32: | ||||
|       bytes_required *= sizeof(float); | ||||
|       break; | ||||
|     case TensorType_INT8: | ||||
|       bytes_required *= sizeof(int8_t); | ||||
|       break; | ||||
|     case TensorType_UINT8: | ||||
|       bytes_required *= sizeof(uint8_t); | ||||
|     case TensorType_FLOAT16: | ||||
|       bytes_required *= sizeof(uint16_t); | ||||
|       break; | ||||
|     case TensorType_INT32: | ||||
|       bytes_required *= sizeof(int32_t); | ||||
|       break; | ||||
|     case TensorType_UINT8: | ||||
|       bytes_required *= sizeof(uint8_t); | ||||
|       break; | ||||
|     case TensorType_INT8: | ||||
|       bytes_required *= sizeof(int8_t); | ||||
|       break; | ||||
|     case TensorType_INT64: | ||||
|       bytes_required *= sizeof(int64_t); | ||||
|       break; | ||||
|     case TensorType_FLOAT16: | ||||
|       // FALLTHROUGH_INTENDED;
 | ||||
|     case TensorType_BOOL: | ||||
|       bytes_required *= sizeof(bool); | ||||
|       break; | ||||
|     case TensorType_INT16: | ||||
|       bytes_required *= sizeof(uint16_t); | ||||
|       break; | ||||
|     case TensorType_COMPLEX64: | ||||
|       bytes_required *= sizeof(std::complex<float>); | ||||
|       break; | ||||
|     default: | ||||
|       ReportError(error_reporter, "Tensor %s invalid type: %d", | ||||
|                   tensor.name()->c_str(), tensor.type()); | ||||
|  | ||||
| @ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #include "tensorflow/lite/tools/verifier.h" | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| @ -21,11 +23,12 @@ limitations under the License. | ||||
| #include <gtest/gtest.h> | ||||
| #include "tensorflow/core/framework/numeric_types.h" | ||||
| #include "tensorflow/lite/allocation.h" | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/error_reporter.h" | ||||
| #include "tensorflow/lite/op_resolver.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| #include "tensorflow/lite/testing/util.h" | ||||
| #include "tensorflow/lite/tools/verifier.h" | ||||
| #include "tensorflow/lite/util.h" | ||||
| #include "tensorflow/lite/version.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| @ -516,6 +519,42 @@ TEST(VerifyModel, OpWithOptionalTensor) { | ||||
|   EXPECT_EQ("", builder.GetErrorString()); | ||||
| } | ||||
| 
 | ||||
| TEST(VerifyModel, TypedTensorShapeMismatchWithTensorBufferSize) { | ||||
|   TfLiteFlatbufferModelBuilder builder; | ||||
|   for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX; | ||||
|        ++tensor_type) { | ||||
|     if (tensor_type == TensorType_STRING) continue; | ||||
|     builder.AddTensor({2, 3}, static_cast<TensorType>(tensor_type), | ||||
|                       {1, 2, 3, 4}, "input"); | ||||
|     builder.FinishModel({}, {}); | ||||
|     ASSERT_FALSE(builder.Verify()); | ||||
|     EXPECT_THAT( | ||||
|         builder.GetErrorString(), | ||||
|         ::testing::ContainsRegex("Tensor input requires .* bytes, but is " | ||||
|                                  "allocated with 4 bytes buffer")); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) { | ||||
|   TfLiteFlatbufferModelBuilder builder; | ||||
|   for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX; | ||||
|        ++tensor_type) { | ||||
|     if (tensor_type == TensorType_STRING) continue; | ||||
|     TfLiteType lite_type = kTfLiteNoType; | ||||
|     ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type), | ||||
|                                 &lite_type, /*error_reporter=*/nullptr), | ||||
|               kTfLiteOk); | ||||
|     size_t size_bytes = 0; | ||||
|     ASSERT_EQ(GetSizeOfType(/*context=*/nullptr, lite_type, &size_bytes), | ||||
|               kTfLiteOk); | ||||
|     std::vector<uint8_t> buffer(size_bytes); | ||||
|     builder.AddTensor({1}, static_cast<TensorType>(tensor_type), buffer, | ||||
|                       "input"); | ||||
|     builder.FinishModel({}, {}); | ||||
|     ASSERT_TRUE(builder.Verify()); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // TODO(yichengfan): make up malicious files to test with.
 | ||||
| 
 | ||||
| }  // namespace tflite
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user