Teach TFLite model verifier about all supported types

PiperOrigin-RevId: 259073172
This commit is contained in:
Jared Duke 2019-07-19 18:18:43 -07:00 committed by TensorFlower Gardener
parent fdc106e412
commit 8f55026cc8
3 changed files with 59 additions and 9 deletions

View File

@ -91,7 +91,8 @@ cc_test(
"//tensorflow/core:framework_lite", "//tensorflow/core:framework_lite",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/c:c_api_internal", "//tensorflow/lite:util",
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util", "//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",

View File

@ -130,20 +130,30 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
case TensorType_FLOAT32: case TensorType_FLOAT32:
bytes_required *= sizeof(float); bytes_required *= sizeof(float);
break; break;
case TensorType_INT8: case TensorType_FLOAT16:
bytes_required *= sizeof(int8_t); bytes_required *= sizeof(uint16_t);
break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
break; break;
case TensorType_INT32: case TensorType_INT32:
bytes_required *= sizeof(int32_t); bytes_required *= sizeof(int32_t);
break; break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
break;
case TensorType_INT8:
bytes_required *= sizeof(int8_t);
break;
case TensorType_INT64: case TensorType_INT64:
bytes_required *= sizeof(int64_t); bytes_required *= sizeof(int64_t);
break; break;
case TensorType_FLOAT16: case TensorType_BOOL:
// FALLTHROUGH_INTENDED; bytes_required *= sizeof(bool);
break;
case TensorType_INT16:
bytes_required *= sizeof(uint16_t);
break;
case TensorType_COMPLEX64:
bytes_required *= sizeof(std::complex<float>);
break;
default: default:
ReportError(error_reporter, "Tensor %s invalid type: %d", ReportError(error_reporter, "Tensor %s invalid type: %d",
tensor.name()->c_str(), tensor.type()); tensor.name()->c_str(), tensor.type());

View File

@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/lite/tools/verifier.h"
#include <string> #include <string>
#include <vector> #include <vector>
@ -21,11 +23,12 @@ limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/op_resolver.h" #include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/testing/util.h"
#include "tensorflow/lite/tools/verifier.h" #include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h" #include "tensorflow/lite/version.h"
namespace tflite { namespace tflite {
@ -516,6 +519,42 @@ TEST(VerifyModel, OpWithOptionalTensor) {
EXPECT_EQ("", builder.GetErrorString()); EXPECT_EQ("", builder.GetErrorString());
} }
TEST(VerifyModel, TypedTensorShapeMismatchWithTensorBufferSize) {
TfLiteFlatbufferModelBuilder builder;
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
++tensor_type) {
if (tensor_type == TensorType_STRING) continue;
builder.AddTensor({2, 3}, static_cast<TensorType>(tensor_type),
{1, 2, 3, 4}, "input");
builder.FinishModel({}, {});
ASSERT_FALSE(builder.Verify());
EXPECT_THAT(
builder.GetErrorString(),
::testing::ContainsRegex("Tensor input requires .* bytes, but is "
"allocated with 4 bytes buffer"));
}
}
TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
TfLiteFlatbufferModelBuilder builder;
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
++tensor_type) {
if (tensor_type == TensorType_STRING) continue;
TfLiteType lite_type = kTfLiteNoType;
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
&lite_type, /*error_reporter=*/nullptr),
kTfLiteOk);
size_t size_bytes = 0;
ASSERT_EQ(GetSizeOfType(/*context=*/nullptr, lite_type, &size_bytes),
kTfLiteOk);
std::vector<uint8_t> buffer(size_bytes);
builder.AddTensor({1}, static_cast<TensorType>(tensor_type), buffer,
"input");
builder.FinishModel({}, {});
ASSERT_TRUE(builder.Verify());
}
}
// TODO(yichengfan): make up malicious files to test with. // TODO(yichengfan): make up malicious files to test with.
} // namespace tflite } // namespace tflite